GorNishanov updated this revision to Diff 88152.
GorNishanov added a comment.

Initialized PromiseRecordDecl to nullptr


https://reviews.llvm.org/D28835

Files:
  include/clang/AST/StmtCXX.h
  include/clang/Sema/ScopeInfo.h
  lib/AST/StmtCXX.cpp
  lib/Sema/SemaCoroutine.cpp
  test/SemaCXX/coroutines.cpp

Index: test/SemaCXX/coroutines.cpp
===================================================================
--- test/SemaCXX/coroutines.cpp
+++ test/SemaCXX/coroutines.cpp
@@ -1,4 +1,4 @@
-// RUN: %clang_cc1 -std=c++14 -fcoroutines-ts -verify %s
+// RUN: %clang_cc1 -std=c++14 -fcoroutines-ts -verify %s -fcxx-exceptions
 
 void no_coroutine_traits_bad_arg_await() {
   co_await a; // expected-error {{include <experimental/coroutine>}}
Index: lib/Sema/SemaCoroutine.cpp
===================================================================
--- lib/Sema/SemaCoroutine.cpp
+++ lib/Sema/SemaCoroutine.cpp
@@ -487,7 +487,7 @@
 static bool buildAllocationAndDeallocation(Sema &S, SourceLocation Loc,
                                            FunctionScopeInfo *Fn,
                                            Expr *&Allocation,
-                                           Stmt *&Deallocation) {
+                                           Expr *&Deallocation) {
   TypeSourceInfo *TInfo = Fn->CoroutinePromise->getTypeSourceInfo();
   QualType PromiseType = TInfo->getType();
   if (PromiseType->isDependentType())
@@ -564,6 +564,48 @@
   return true;
 }
 
+namespace {
+/// Builds the implicit sub-statements that CoroutineBodyStmt::Create needs.
+class SubStmtBuilder : public CoroutineBodyStmt::CtorArgs {
+  Sema &S;
+  FunctionDecl &FD;
+  FunctionScopeInfo &Fn;
+  bool IsValid;
+  SourceLocation Loc;
+  SmallVector<Stmt *, 4> ParamMovesVector;
+  const bool IsPromiseDependentType;
+  CXXRecordDecl *PromiseRecordDecl = nullptr;
+
+public:
+  SubStmtBuilder(Sema &S, FunctionDecl &FD, FunctionScopeInfo &Fn, Stmt *Body)
+      : S(S), FD(FD), Fn(Fn), Loc(FD.getLocation()),
+        IsPromiseDependentType(
+            !Fn.CoroutinePromise ||
+            Fn.CoroutinePromise->getType()->isDependentType()) {
+    this->Body = Body;
+    if (!IsPromiseDependentType) {
+      PromiseRecordDecl = Fn.CoroutinePromise->getType()->getAsCXXRecordDecl();
+      assert(PromiseRecordDecl && "Type should have already been checked");
+    }
+    this->IsValid = makePromiseStmt() && makeInitialSuspend() &&
+                    makeFinalSuspend() && makeOnException() &&
+                    makeOnFallthrough() && makeNewAndDeleteExpr() &&
+                    makeReturnObject() && makeParamMoves();
+  }
+
+  bool isInvalid() const { return !this->IsValid; }
+
+  bool makePromiseStmt();
+  bool makeInitialSuspend();
+  bool makeFinalSuspend();
+  bool makeNewAndDeleteExpr();
+  bool makeOnFallthrough();
+  bool makeOnException();
+  bool makeReturnObject();
+  bool makeParamMoves();
+};
+} // end anonymous namespace
+
 void Sema::CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body) {
   FunctionScopeInfo *Fn = getCurFunction();
   assert(Fn && !Fn->CoroutineStmts.empty() && "not a coroutine");
@@ -577,120 +619,159 @@
         << (isa<CoawaitExpr>(First) ? 0 :
             isa<CoyieldExpr>(First) ? 1 : 2);
   }
+  SubStmtBuilder Builder(*this, *FD, *Fn, Body);
+  if (Builder.isInvalid())
+    return FD->setInvalidDecl();
 
-  SourceLocation Loc = FD->getLocation();
+  // Build body for the coroutine wrapper statement.
+  Body = CoroutineBodyStmt::Create(Context, Builder);
+}
 
+bool SubStmtBuilder::makePromiseStmt() {
   // Form a declaration statement for the promise declaration, so that AST
   // visitors can more easily find it.
   StmtResult PromiseStmt =
-      ActOnDeclStmt(ConvertDeclToDeclGroup(Fn->CoroutinePromise), Loc, Loc);
+      S.ActOnDeclStmt(S.ConvertDeclToDeclGroup(Fn.CoroutinePromise), Loc, Loc);
   if (PromiseStmt.isInvalid())
-    return FD->setInvalidDecl();
+    return false;
+
+  this->Promise = PromiseStmt.get();
+  return true;
+}
 
+bool SubStmtBuilder::makeInitialSuspend() {
   // Form and check implicit 'co_await p.initial_suspend();' statement.
   ExprResult InitialSuspend =
-      buildPromiseCall(*this, Fn, Loc, "initial_suspend", None);
+      buildPromiseCall(S, &Fn, Loc, "initial_suspend", None);
   // FIXME: Support operator co_await here.
   if (!InitialSuspend.isInvalid())
-    InitialSuspend = BuildCoawaitExpr(Loc, InitialSuspend.get());
-  InitialSuspend = ActOnFinishFullExpr(InitialSuspend.get());
+    InitialSuspend = S.BuildCoawaitExpr(Loc, InitialSuspend.get());
+  InitialSuspend = S.ActOnFinishFullExpr(InitialSuspend.get());
   if (InitialSuspend.isInvalid())
-    return FD->setInvalidDecl();
+    return false;
+
+  this->InitialSuspend = InitialSuspend.get();
+  return true;
+}
 
+bool SubStmtBuilder::makeFinalSuspend() {
   // Form and check implicit 'co_await p.final_suspend();' statement.
   ExprResult FinalSuspend =
-      buildPromiseCall(*this, Fn, Loc, "final_suspend", None);
+      buildPromiseCall(S, &Fn, Loc, "final_suspend", None);
   // FIXME: Support operator co_await here.
   if (!FinalSuspend.isInvalid())
-    FinalSuspend = BuildCoawaitExpr(Loc, FinalSuspend.get());
-  FinalSuspend = ActOnFinishFullExpr(FinalSuspend.get());
+    FinalSuspend = S.BuildCoawaitExpr(Loc, FinalSuspend.get());
+  FinalSuspend = S.ActOnFinishFullExpr(FinalSuspend.get());
   if (FinalSuspend.isInvalid())
-    return FD->setInvalidDecl();
+    return false;
 
+  this->FinalSuspend = FinalSuspend.get();
+  return true;
+}
+
+bool SubStmtBuilder::makeNewAndDeleteExpr() {
   // Form and check allocation and deallocation calls.
-  Expr *Allocation = nullptr;
-  Stmt *Deallocation = nullptr;
-  if (!buildAllocationAndDeallocation(*this, Loc, Fn, Allocation, Deallocation))
-    return FD->setInvalidDecl();
+  return buildAllocationAndDeallocation(S, Loc, &Fn, this->Allocate,
+                                        this->Deallocate);
+}
+
+bool SubStmtBuilder::makeOnFallthrough() {
+  // PromiseRecordDecl is only set when the promise type is not dependent.
+  if (!PromiseRecordDecl)
+    return true;
+  // [dcl.fct.def.coroutine]/4
+  // The unqualified-ids 'return_void' and 'return_value' are looked up in
+  // the scope of class P. If both are found, the program is ill-formed.
+  DeclarationName RVoidDN = S.PP.getIdentifierInfo("return_void");
+  LookupResult RVoidResult(S, RVoidDN, Loc, Sema::LookupMemberName);
+  const bool HasRVoid = S.LookupQualifiedName(RVoidResult, PromiseRecordDecl);
 
-  // control flowing off the end of the coroutine.
-  // Also try to form 'p.set_exception(std::current_exception());' to handle
+  DeclarationName RValueDN = S.PP.getIdentifierInfo("return_value");
+  LookupResult RValueResult(S, RValueDN, Loc, Sema::LookupMemberName);
+  const bool HasRValue = S.LookupQualifiedName(RValueResult, PromiseRecordDecl);
+
+  if (HasRVoid && HasRValue) {
+    // FIXME: Improve this diagnostic.
+    S.Diag(Loc, diag::err_coroutine_promise_return_ill_formed)
+        << PromiseRecordDecl;
+    return false;
+  }
+  // If the unqualified-id return_void is found, flowing off the end of a
+  // coroutine is equivalent to a co_return with no operand. Otherwise,
+  // flowing off the end of a coroutine results in undefined behavior.
+  StmtResult Fallthrough;
+  if (HasRVoid) {
+    Fallthrough = S.BuildCoreturnStmt(Loc, nullptr);
+    Fallthrough = S.ActOnFinishFullStmt(Fallthrough.get());
+    if (Fallthrough.isInvalid())
+      return false;
+  }
+  this->OnFallthrough = Fallthrough.get();
+  return true;
+}
+
+bool SubStmtBuilder::makeOnException() {
+  // Try to form 'p.set_exception(std::current_exception());' to handle
   // uncaught exceptions.
+  // TODO: Post-Issaquah 2016, WG21 renamed set_exception to
+  // unhandled_exception and dropped its exception_ptr parameter. Make it so.
+
+  if (!PromiseRecordDecl)
+    return true;
+
+  // If exceptions are disabled, don't try to build OnException.
+  if (!S.getLangOpts().CXXExceptions)
+    return true;
+
   ExprResult SetException;
-  StmtResult Fallthrough;
-  if (Fn->CoroutinePromise &&
-      !Fn->CoroutinePromise->getType()->isDependentType()) {
-    CXXRecordDecl *RD = Fn->CoroutinePromise->getType()->getAsCXXRecordDecl();
-    assert(RD && "Type should have already been checked");
-    // [dcl.fct.def.coroutine]/4
-    // The unqualified-ids 'return_void' and 'return_value' are looked up in
-    // the scope of class P. If both are found, the program is ill-formed.
-    DeclarationName RVoidDN = PP.getIdentifierInfo("return_void");
-    LookupResult RVoidResult(*this, RVoidDN, Loc, Sema::LookupMemberName);
-    const bool HasRVoid = LookupQualifiedName(RVoidResult, RD);
-
-    DeclarationName RValueDN = PP.getIdentifierInfo("return_value");
-    LookupResult RValueResult(*this, RValueDN, Loc, Sema::LookupMemberName);
-    const bool HasRValue = LookupQualifiedName(RValueResult, RD);
-
-    if (HasRVoid && HasRValue) {
-      // FIXME Improve this diagnostic
-      Diag(FD->getLocation(), diag::err_coroutine_promise_return_ill_formed)
-          << RD;
-      return FD->setInvalidDecl();
-    } else if (HasRVoid) {
-      // If the unqualified-id return_void is found, flowing off the end of a
-      // coroutine is equivalent to a co_return with no operand. Otherwise,
-      // flowing off the end of a coroutine results in undefined behavior.
-      Fallthrough = BuildCoreturnStmt(FD->getLocation(), nullptr);
-      Fallthrough = ActOnFinishFullStmt(Fallthrough.get());
-      if (Fallthrough.isInvalid())
-        return FD->setInvalidDecl();
-    }
 
-    // [dcl.fct.def.coroutine]/3
-    // The unqualified-id set_exception is found in the scope of P by class
-    // member access lookup (3.4.5).
-    DeclarationName SetExDN = PP.getIdentifierInfo("set_exception");
-    LookupResult SetExResult(*this, SetExDN, Loc, Sema::LookupMemberName);
-    if (LookupQualifiedName(SetExResult, RD)) {
-      // Form the call 'p.set_exception(std::current_exception())'
-      SetException = buildStdCurrentExceptionCall(*this, Loc);
-      if (SetException.isInvalid())
-        return FD->setInvalidDecl();
-      Expr *E = SetException.get();
-      SetException = buildPromiseCall(*this, Fn, Loc, "set_exception", E);
-      SetException = ActOnFinishFullExpr(SetException.get(), Loc);
-      if (SetException.isInvalid())
-        return FD->setInvalidDecl();
-    }
+  // [dcl.fct.def.coroutine]/3
+  // The unqualified-id set_exception is found in the scope of P by class
+  // member access lookup (3.4.5).
+  DeclarationName SetExDN = S.PP.getIdentifierInfo("set_exception");
+  LookupResult SetExResult(S, SetExDN, Loc, Sema::LookupMemberName);
+  if (S.LookupQualifiedName(SetExResult, PromiseRecordDecl)) {
+    // Form the call 'p.set_exception(std::current_exception())'
+    SetException = buildStdCurrentExceptionCall(S, Loc);
+    if (SetException.isInvalid())
+      return false;
+    Expr *E = SetException.get();
+    SetException = buildPromiseCall(S, &Fn, Loc, "set_exception", E);
+    SetException = S.ActOnFinishFullExpr(SetException.get(), Loc);
+    if (SetException.isInvalid())
+      return false;
   }
 
+  this->OnException = SetException.get();
+  return true;
+}
+
+bool SubStmtBuilder::makeReturnObject() {
+
   // Build implicit 'p.get_return_object()' expression and form initialization
   // of return type from it.
   ExprResult ReturnObject =
-      buildPromiseCall(*this, Fn, Loc, "get_return_object", None);
+      buildPromiseCall(S, &Fn, Loc, "get_return_object", None);
   if (ReturnObject.isInvalid())
-    return FD->setInvalidDecl();
-  QualType RetType = FD->getReturnType();
+    return false;
+  QualType RetType = FD.getReturnType();
   if (!RetType->isDependentType()) {
     InitializedEntity Entity =
         InitializedEntity::InitializeResult(Loc, RetType, false);
-    ReturnObject = PerformMoveOrCopyInitialization(Entity, nullptr, RetType,
+    ReturnObject = S.PerformMoveOrCopyInitialization(Entity, nullptr, RetType,
                                                    ReturnObject.get());
     if (ReturnObject.isInvalid())
-      return FD->setInvalidDecl();
+      return false;
   }
-  ReturnObject = ActOnFinishFullExpr(ReturnObject.get(), Loc);
+  ReturnObject = S.ActOnFinishFullExpr(ReturnObject.get(), Loc);
   if (ReturnObject.isInvalid())
-    return FD->setInvalidDecl();
+    return false;
 
-  // FIXME: Perform move-initialization of parameters into frame-local copies.
-  SmallVector<Expr*, 16> ParamMoves;
+  this->ReturnValue = ReturnObject.get();
+  return true;
+}
 
-  // Build body for the coroutine wrapper statement.
-  Body = new (Context) CoroutineBodyStmt(
-      Body, PromiseStmt.get(), InitialSuspend.get(), FinalSuspend.get(),
-      SetException.get(), Fallthrough.get(), Allocation, Deallocation,
-      ReturnObject.get(), ParamMoves);
+bool SubStmtBuilder::makeParamMoves() {
+  // FIXME: Perform move-initialization of parameters into frame-local copies.
+  return true;
 }
Index: lib/AST/StmtCXX.cpp
===================================================================
--- lib/AST/StmtCXX.cpp
+++ lib/AST/StmtCXX.cpp
@@ -86,3 +86,28 @@
 const VarDecl *CXXForRangeStmt::getLoopVariable() const {
   return const_cast<CXXForRangeStmt *>(this)->getLoopVariable();
 }
+
+CoroutineBodyStmt *CoroutineBodyStmt::Create(
+    const ASTContext &C, CoroutineBodyStmt::CtorArgs const &Args) {
+  // Room for the fixed sub-statements plus one slot per parameter move.
+  std::size_t Size = totalSizeToAlloc<Stmt *>(
+      CoroutineBodyStmt::FirstParamMove + Args.ParamMoves.size());
+  void *Mem = C.Allocate(Size, alignof(CoroutineBodyStmt));
+  return new (Mem) CoroutineBodyStmt(Args);
+}
+
+CoroutineBodyStmt::CoroutineBodyStmt(CoroutineBodyStmt::CtorArgs const &Args)
+    : Stmt(CoroutineBodyStmtClass), NumParams(Args.ParamMoves.size()) {
+  Stmt **SubStmts = getStoredStmts();
+  SubStmts[CoroutineBodyStmt::Body] = Args.Body;
+  SubStmts[CoroutineBodyStmt::Promise] = Args.Promise;
+  SubStmts[CoroutineBodyStmt::InitSuspend] = Args.InitialSuspend;
+  SubStmts[CoroutineBodyStmt::FinalSuspend] = Args.FinalSuspend;
+  SubStmts[CoroutineBodyStmt::OnException] = Args.OnException;
+  SubStmts[CoroutineBodyStmt::OnFallthrough] = Args.OnFallthrough;
+  SubStmts[CoroutineBodyStmt::Allocate] = Args.Allocate;
+  SubStmts[CoroutineBodyStmt::Deallocate] = Args.Deallocate;
+  SubStmts[CoroutineBodyStmt::ReturnValue] = Args.ReturnValue;
+  std::copy(Args.ParamMoves.begin(), Args.ParamMoves.end(),
+            SubStmts + CoroutineBodyStmt::FirstParamMove);
+}
Index: include/clang/Sema/ScopeInfo.h
===================================================================
--- include/clang/Sema/ScopeInfo.h
+++ include/clang/Sema/ScopeInfo.h
@@ -157,7 +157,7 @@
   SmallVector<ReturnStmt*, 4> Returns;
 
   /// \brief The promise object for this coroutine, if any.
-  VarDecl *CoroutinePromise;
+  VarDecl *CoroutinePromise = nullptr;
 
   /// \brief The list of coroutine control flow constructs (co_await, co_yield,
   /// co_return) that occur within the function or block. Empty if and only if
Index: include/clang/AST/StmtCXX.h
===================================================================
--- include/clang/AST/StmtCXX.h
+++ include/clang/AST/StmtCXX.h
@@ -296,7 +296,9 @@
 /// \brief Represents the body of a coroutine. This wraps the normal function
 /// body and holds the additional semantic context required to set up and tear
 /// down the coroutine frame.
-class CoroutineBodyStmt : public Stmt {
+class CoroutineBodyStmt final
+    : public Stmt,
+      private llvm::TrailingObjects<CoroutineBodyStmt, Stmt *> {
   enum SubStmt {
     Body,          ///< The body of the coroutine.
     Promise,       ///< The promise statement.
@@ -309,52 +311,76 @@
     ReturnValue,   ///< Return value for thunk function.
     FirstParamMove ///< First offset for move construction of parameter copies.
   };
-  Stmt *SubStmts[SubStmt::FirstParamMove];
+  unsigned NumParams; ///< Number of trailing parameter-move statements.
 
   friend class ASTStmtReader;
+  friend TrailingObjects;
+
+  /// Fixed sub-statements (indexed by SubStmt), then NumParams param moves.
+  Stmt **getStoredStmts() { return getTrailingObjects<Stmt *>(); }
+  Stmt *const *getStoredStmts() const { return getTrailingObjects<Stmt *>(); }
+
 public:
-  CoroutineBodyStmt(Stmt *Body, Stmt *Promise, Stmt *InitSuspend,
-                    Stmt *FinalSuspend, Stmt *OnException, Stmt *OnFallthrough,
-                    Expr *Allocate, Stmt *Deallocate,
-                    Expr *ReturnValue, ArrayRef<Expr *> ParamMoves)
-      : Stmt(CoroutineBodyStmtClass) {
-    SubStmts[CoroutineBodyStmt::Body] = Body;
-    SubStmts[CoroutineBodyStmt::Promise] = Promise;
-    SubStmts[CoroutineBodyStmt::InitSuspend] = InitSuspend;
-    SubStmts[CoroutineBodyStmt::FinalSuspend] = FinalSuspend;
-    SubStmts[CoroutineBodyStmt::OnException] = OnException;
-    SubStmts[CoroutineBodyStmt::OnFallthrough] = OnFallthrough;
-    SubStmts[CoroutineBodyStmt::Allocate] = Allocate;
-    SubStmts[CoroutineBodyStmt::Deallocate] = Deallocate;
-    SubStmts[CoroutineBodyStmt::ReturnValue] = ReturnValue;
-    // FIXME: Tail-allocate space for parameter move expressions and store them.
-    assert(ParamMoves.empty() && "not implemented yet");
-  }
+  /// Sub-statements passed to Create(); unset members default to null.
+  struct CtorArgs {
+    Stmt *Body = nullptr;
+    Stmt *Promise = nullptr;
+    Expr *InitialSuspend = nullptr;
+    Expr *FinalSuspend = nullptr;
+    Stmt *OnException = nullptr;
+    Stmt *OnFallthrough = nullptr;
+    Expr *Allocate = nullptr;
+    Expr *Deallocate = nullptr;
+    Stmt *ReturnValue = nullptr;
+    ArrayRef<Stmt *> ParamMoves; ///< Not owned; copied into the new node.
+  };
+
+private:
+  explicit CoroutineBodyStmt(CtorArgs const &Args);
+
+public:
+  /// Allocates in \p C, with trailing storage sized for Args.ParamMoves.
+  static CoroutineBodyStmt *Create(const ASTContext &C, CtorArgs const &Args);
 
   /// \brief Retrieve the body of the coroutine as written. This will be either
   /// a CompoundStmt or a TryStmt.
   Stmt *getBody() const {
-    return SubStmts[SubStmt::Body];
+    return getStoredStmts()[SubStmt::Body];
   }
 
-  Stmt *getPromiseDeclStmt() const { return SubStmts[SubStmt::Promise]; }
+  Stmt *getPromiseDeclStmt() const {
+    return getStoredStmts()[SubStmt::Promise];
+  }
   VarDecl *getPromiseDecl() const {
     return cast<VarDecl>(cast<DeclStmt>(getPromiseDeclStmt())->getSingleDecl());
   }
 
-  Stmt *getInitSuspendStmt() const { return SubStmts[SubStmt::InitSuspend]; }
-  Stmt *getFinalSuspendStmt() const { return SubStmts[SubStmt::FinalSuspend]; }
+  Stmt *getInitSuspendStmt() const {
+    return getStoredStmts()[SubStmt::InitSuspend];
+  }
+  Stmt *getFinalSuspendStmt() const {
+    return getStoredStmts()[SubStmt::FinalSuspend];
+  }
 
-  Stmt *getExceptionHandler() const { return SubStmts[SubStmt::OnException]; }
+  Stmt *getExceptionHandler() const {
+    return getStoredStmts()[SubStmt::OnException];
+  }
   Stmt *getFallthroughHandler() const {
-    return SubStmts[SubStmt::OnFallthrough];
+    return getStoredStmts()[SubStmt::OnFallthrough];
   }
 
-  Expr *getAllocate() const { return cast<Expr>(SubStmts[SubStmt::Allocate]); }
-  Stmt *getDeallocate() const { return SubStmts[SubStmt::Deallocate]; }
+  Expr *getAllocate() const {
+    return cast_or_null<Expr>(getStoredStmts()[SubStmt::Allocate]);
+  }
+  Expr *getDeallocate() const {
+    return cast_or_null<Expr>(getStoredStmts()[SubStmt::Deallocate]);
+  }
 
   Expr *getReturnValueInit() const {
-    return cast<Expr>(SubStmts[SubStmt::ReturnValue]);
+    return cast<Expr>(getStoredStmts()[SubStmt::ReturnValue]);
+  }
+  ArrayRef<Stmt const *> getParamMoves() const {
+    return {getStoredStmts() + SubStmt::FirstParamMove, NumParams};
   }
 
   SourceLocation getLocStart() const LLVM_READONLY {
@@ -365,7 +391,8 @@
   }
 
   child_range children() {
-    return child_range(SubStmts, SubStmts + SubStmt::FirstParamMove);
+    return child_range(getStoredStmts(),
+                       getStoredStmts() + SubStmt::FirstParamMove + NumParams);
   }
 
   static bool classof(const Stmt *T) {
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
http://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to