GorNishanov created this revision.
GorNishanov added reviewers: rsmith, EricWF.
GorNishanov added a subscriber: cfe-commits.
Herald added a subscriber: mehdi_amini.

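1. Sema: Add allocation and deallocation substatements to CoroutineBodyStmt.
2. Sema: Add labels to the final-suspend and deallocation substatements.
3. Sema: Allow a coroutine whose only coroutine statement is co_return (removes the -Wcoreturn-without-coawait warning).
4. CodeGen: Emit coroutine frame allocation and deallocation, with a test.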


https://reviews.llvm.org/D25258

Files:
  include/clang/AST/StmtCXX.h
  include/clang/Basic/DiagnosticSemaKinds.td
  lib/CodeGen/CGCoroutine.cpp
  lib/CodeGen/CGStmt.cpp
  lib/CodeGen/CodeGenFunction.h
  lib/Sema/SemaCoroutine.cpp
  test/CodeGenCoroutines/coro-alloc.cpp
  test/SemaCXX/coroutines.cpp

Index: test/SemaCXX/coroutines.cpp
===================================================================
--- test/SemaCXX/coroutines.cpp
+++ test/SemaCXX/coroutines.cpp
@@ -143,13 +143,12 @@
 }
 
 void only_coreturn() {
-  co_return; // expected-warning {{'co_return' used in a function that uses neither 'co_await' nor 'co_yield'}}
+  co_return; // OK: a lone co_return is enough to make this a coroutine.
 }
 
 void mixed_coreturn(bool b) {
   if (b)
-    // expected-warning@+1 {{'co_return' used in a function that uses neither}}
-    co_return; // expected-note {{use of 'co_return'}}
+    co_return; // expected-note {{use of 'co_return' here}}
   else
     return; // expected-error {{not allowed in coroutine}}
 }
Index: test/CodeGenCoroutines/coro-alloc.cpp
===================================================================
--- /dev/null
+++ test/CodeGenCoroutines/coro-alloc.cpp
@@ -0,0 +1,118 @@
+// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fcoroutines-ts -std=c++14 -emit-llvm %s -o - -disable-llvm-passes | FileCheck %s
+
+namespace std {
+namespace experimental {
+template <typename... T>
+struct coroutine_traits;
+}
+}
+
+struct suspend_always {
+  bool await_ready() { return false; }
+  void await_suspend() {}
+  void await_resume() {}
+};
+
+struct global_new_delete_tag {};
+
+template<>
+struct std::experimental::coroutine_traits<void, global_new_delete_tag> {
+  struct promise_type {
+    void get_return_object() {}
+    suspend_always initial_suspend() { return {}; }
+    suspend_always final_suspend() { return {}; }
+    void return_void() {}
+  };
+};
+
+// CHECK-LABEL: f0(
+extern "C" void f0(global_new_delete_tag) {
+  // CHECK: %[[ID:.+]] = call token @llvm.coro.id(
+  // CHECK: %[[SIZE:.+]] = call i64 @llvm.coro.size.i64()
+  // CHECK: call i8* @_Znwm(i64 %[[SIZE]])
+
+  // CHECK: coro.destroy.label:
+  // CHECK: %[[FRAME:.+]] = call i8* @llvm.coro.frame()
+  // CHECK: %[[MEM:.+]] = call i8* @llvm.coro.free(token %[[ID]], i8* %[[FRAME]])
+  // CHECK: call void @_ZdlPv(i8* %[[MEM]])
+  co_return;
+}
+
+struct promise_new_tag {};
+
+template<>
+struct std::experimental::coroutine_traits<void, promise_new_tag> {
+  struct promise_type {
+    void *operator new(unsigned long);
+    void get_return_object() {}
+    suspend_always initial_suspend() { return {}; }
+    suspend_always final_suspend() { return {}; }
+    void return_void() {}
+  };
+};
+
+// CHECK-LABEL: f1(
+extern "C" void f1(promise_new_tag) {
+  // CHECK: %[[ID:.+]] = call token @llvm.coro.id(
+  // CHECK: %[[SIZE:.+]] = call i64 @llvm.coro.size.i64()
+  // CHECK: call i8* @_ZNSt12experimental16coroutine_traitsIJv15promise_new_tagEE12promise_typenwEm(i64 %[[SIZE]])
+
+  // CHECK: coro.destroy.label:
+  // CHECK: %[[FRAME:.+]] = call i8* @llvm.coro.frame()
+  // CHECK: %[[MEM:.+]] = call i8* @llvm.coro.free(token %[[ID]], i8* %[[FRAME]])
+  // CHECK: call void @_ZdlPv(i8* %[[MEM]])
+  co_return;
+}
+
+struct promise_delete_tag {};
+
+template<>
+struct std::experimental::coroutine_traits<void, promise_delete_tag> {
+  struct promise_type {
+    void operator delete(void*);
+    void get_return_object() {}
+    suspend_always initial_suspend() { return {}; }
+    suspend_always final_suspend() { return {}; }
+    void return_void() {}
+  };
+};
+
+// CHECK-LABEL: f2(
+extern "C" void f2(promise_delete_tag) {
+  // CHECK: %[[ID:.+]] = call token @llvm.coro.id(
+  // CHECK: %[[SIZE:.+]] = call i64 @llvm.coro.size.i64()
+  // CHECK: call i8* @_Znwm(i64 %[[SIZE]])
+
+  // CHECK: coro.destroy.label:
+  // CHECK: %[[FRAME:.+]] = call i8* @llvm.coro.frame()
+  // CHECK: %[[MEM:.+]] = call i8* @llvm.coro.free(token %[[ID]], i8* %[[FRAME]])
+  // CHECK: call void @_ZNSt12experimental16coroutine_traitsIJv18promise_delete_tagEE12promise_typedlEPv(i8* %[[MEM]])
+  co_return;
+}
+
+struct promise_sized_delete_tag {};
+
+template<>
+struct std::experimental::coroutine_traits<void, promise_sized_delete_tag> {
+  struct promise_type {
+    void operator delete(void*, unsigned long);
+    void get_return_object() {}
+    suspend_always initial_suspend() { return {}; }
+    suspend_always final_suspend() { return {}; }
+    void return_void() {}
+  };
+};
+
+// CHECK-LABEL: f3(
+extern "C" void f3(promise_sized_delete_tag) {
+  // CHECK: %[[ID:.+]] = call token @llvm.coro.id(
+  // CHECK: %[[SIZE:.+]] = call i64 @llvm.coro.size.i64()
+  // CHECK: call i8* @_Znwm(i64 %[[SIZE]])
+
+  // CHECK: coro.destroy.label:
+  // CHECK: %[[FRAME:.+]] = call i8* @llvm.coro.frame()
+  // CHECK: %[[MEM:.+]] = call i8* @llvm.coro.free(token %[[ID]], i8* %[[FRAME]])
+  // CHECK: %[[SIZE2:.+]] = call i64 @llvm.coro.size.i64()
+  // CHECK: call void @_ZNSt12experimental16coroutine_traitsIJv24promise_sized_delete_tagEE12promise_typedlEPvm(i8* %[[MEM]], i64 %[[SIZE2]])
+  co_return;
+}
Index: lib/Sema/SemaCoroutine.cpp
===================================================================
--- lib/Sema/SemaCoroutine.cpp
+++ lib/Sema/SemaCoroutine.cpp
@@ -378,6 +378,154 @@
   return Res;
 }
 
+// Builds a call to the builtin identified by \p id with arguments \p CallArgs.
+// The builtin is looked up by name at translation-unit scope.
+static Expr *buildBuiltinCall(Sema &S, SourceLocation Loc, Builtin::ID id,
+                              MutableArrayRef<Expr *> CallArgs) {
+  StringRef Name = S.Context.BuiltinInfo.getName(id);
+  LookupResult R(S, &S.Context.Idents.get(Name), Loc, Sema::LookupOrdinaryName);
+  S.LookupName(R, S.TUScope, true);
+
+  FunctionDecl *BuiltInDecl = R.getAsSingle<FunctionDecl>();
+  assert(BuiltInDecl && "failed to find builtin declaration");
+
+  ExprResult DeclRef = S.BuildDeclRefExpr(BuiltInDecl, BuiltInDecl->getType(),
+                                          VK_RValue, Loc, nullptr);
+  assert(DeclRef.isUsable() && "Builtin reference cannot fail");
+
+  ExprResult Call =
+      S.ActOnCallExpr(/*Scope=*/nullptr, DeclRef.get(), Loc, CallArgs, Loc);
+
+  assert(!Call.isInvalid() && "Call to builtin cannot fail!");
+  return Call.get();
+}
+
+// Find an appropriate operator delete for the promise: the deallocation
+// function found in the promise class if there is one, otherwise the usual
+// global deallocation function. Returns nullptr if lookup reports an error.
+static FunctionDecl *findDeleteForPromise(Sema &S, SourceLocation Loc,
+                                          QualType PromiseType) {
+  FunctionDecl *OperatorDelete = nullptr;
+
+  DeclarationName DeleteName =
+      S.Context.DeclarationNames.getCXXOperatorName(OO_Delete);
+
+  CXXRecordDecl *PromiseRD = PromiseType->getAsCXXRecordDecl();
+  assert(PromiseRD && "PromiseType must be a CXXRecordDecl type");
+
+  if (S.FindDeallocationFunction(Loc, PromiseRD, DeleteName, OperatorDelete))
+    return nullptr;
+
+  if (!OperatorDelete) {
+    // Look for a global declaration.
+    OperatorDelete = S.FindUsualDeallocationFunction(
+        Loc, S.isCompleteType(Loc, PromiseType), DeleteName);
+
+    S.MarkFunctionReferenced(Loc, OperatorDelete);
+  }
+  return OperatorDelete;
+}
+
+// Builds the coroutine frame allocation, operator new(__builtin_coro_size()),
+// and the deallocation statement "coro.destroy.label:" calling operator delete
+// on __builtin_coro_free(__builtin_coro_frame()) (plus the frame size for a
+// sized delete). Returns false on failure. For a dependent promise type,
+// returns true and leaves Allocation and Deallocation unset.
+static bool buildAllocationAndDeallocation(Sema &S, SourceLocation Loc,
+                                           FunctionScopeInfo *Fn,
+                                           Expr *&Allocation,
+                                           LabelStmt *&Deallocation) {
+  TypeSourceInfo *TInfo = Fn->CoroutinePromise->getTypeSourceInfo();
+  QualType PromiseType = TInfo->getType();
+  if (PromiseType->isDependentType())
+    return true;
+
+  if (S.RequireCompleteType(Loc, PromiseType, diag::err_incomplete_type))
+    return false;
+
+  // FIXME: Add support for get_return_object_on_allocation failure.
+  // FIXME: Add support for stateful allocators.
+
+  FunctionDecl *OperatorNew = nullptr;
+  FunctionDecl *OperatorDelete = nullptr;
+  FunctionDecl *UnusedResult = nullptr;
+
+  // FindAllocationFunctions returns true if it diagnosed an error.
+  if (S.FindAllocationFunctions(Loc, SourceRange(),
+                                /*UseGlobal*/ false, PromiseType,
+                                /*isArray*/ false, /*PlacementArgs*/ None,
+                                OperatorNew, UnusedResult))
+    return false;
+
+  OperatorDelete = findDeleteForPromise(S, Loc, PromiseType);
+
+  if (!OperatorDelete || !OperatorNew)
+    return false;
+
+  // Make new call: operator new(__builtin_coro_size()).
+
+  Expr *FrameSize =
+      buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_size, {});
+
+  ExprResult NewRef =
+      S.BuildDeclRefExpr(OperatorNew, OperatorNew->getType(), VK_LValue, Loc);
+  if (NewRef.isInvalid())
+    return false;
+
+  ExprResult NewExpr = S.ActOnCallExpr(S.getCurScope(), NewRef.get(), Loc,
+                                       FrameSize, Loc, nullptr);
+  if (NewExpr.isInvalid())
+    return false;
+
+  Allocation = NewExpr.get();
+
+  // Make delete call: operator delete(__builtin_coro_free(frame)[, size]).
+
+  QualType OpDeleteQualType = OperatorDelete->getType();
+
+  ExprResult DeleteRef =
+      S.BuildDeclRefExpr(OperatorDelete, OpDeleteQualType, VK_LValue, Loc);
+  if (DeleteRef.isInvalid())
+    return false;
+
+  Expr *FramePtr =
+      buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_frame, {});
+
+  Expr *CoroFree =
+      buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_free, {FramePtr});
+
+  SmallVector<Expr *, 2> DeleteArgs{CoroFree};
+
+  // Check if we need to pass the size. Build a fresh __builtin_coro_size()
+  // call rather than reusing FrameSize: FrameSize is already the argument of
+  // the allocation call, and Clang's AST is expected to be a tree, so an Expr
+  // node should not appear in two places.
+  const auto *OpDeleteType = OpDeleteQualType->getAs<FunctionProtoType>();
+  if (OpDeleteType->getNumParams() > 1)
+    DeleteArgs.push_back(
+        buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_size, {}));
+
+  ExprResult DeleteExpr = S.ActOnCallExpr(S.getCurScope(), DeleteRef.get(), Loc,
+                                          DeleteArgs, Loc, nullptr);
+  if (DeleteExpr.isInvalid())
+    return false;
+
+  // Make it a labeled statement. Suspend point emission uses this label as a
+  // jump target for the cleanup branch.
+  LabelDecl *DestroyLabel =
+      LabelDecl::Create(S.Context, S.CurContext, SourceLocation(),
+                        S.PP.getIdentifierInfo("coro.destroy.label"));
+
+  StmtResult Stmt = S.ActOnLabelStmt(Loc, DestroyLabel, Loc, DeleteExpr.get());
+
+  if (Stmt.isInvalid())
+    return false;
+
+  Deallocation = cast<LabelStmt>(Stmt.get());
+
+  return true;
+}
+
 void Sema::CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body) {
   FunctionScopeInfo *Fn = getCurFunction();
   assert(Fn && !Fn->CoroutineStmts.empty() && "not a coroutine");
@@ -388,21 +536,9 @@
     Diag(Fn->FirstReturnLoc, diag::err_return_in_coroutine);
     auto *First = Fn->CoroutineStmts[0];
     Diag(First->getLocStart(), diag::note_declared_coroutine_here)
-      << (isa<CoawaitExpr>(First) ? 0 :
-          isa<CoyieldExpr>(First) ? 1 : 2);
-  }
-
-  bool AnyCoawaits = false;
-  bool AnyCoyields = false;
-  for (auto *CoroutineStmt : Fn->CoroutineStmts) {
-    AnyCoawaits |= isa<CoawaitExpr>(CoroutineStmt);
-    AnyCoyields |= isa<CoyieldExpr>(CoroutineStmt);
+        << (isa<CoawaitExpr>(First) ? 0 : isa<CoyieldExpr>(First) ? 1 : 2);
   }
 
-  if (!AnyCoawaits && !AnyCoyields)
-    Diag(Fn->CoroutineStmts.front()->getLocStart(),
-         diag::ext_coroutine_without_co_await_co_yield);
-
   SourceLocation Loc = FD->getLocation();
 
   // Form a declaration statement for the promise declaration, so that AST
@@ -432,15 +568,31 @@
   if (FinalSuspend.isInvalid())
     return FD->setInvalidDecl();
 
+  // Add a label to a final suspend. It will be the jump target for co_return
+  // statements.
+  LabelDecl *FinalLabel =
+      LabelDecl::Create(Context, CurContext, SourceLocation(),
+                        PP.getIdentifierInfo("coro.final.label"));
+  StmtResult FinalSuspendWithLabel =
+      ActOnLabelStmt(Loc, FinalLabel, Loc, FinalSuspend.get());
+  if (FinalSuspendWithLabel.isInvalid())
+    return FD->setInvalidDecl();
+
+  // Build allocation function and deallocation expressions.
+  Expr *Allocation = nullptr;
+  LabelStmt *Deallocation = nullptr;
+  if (!buildAllocationAndDeallocation(*this, Loc, Fn, Allocation, Deallocation))
+    return FD->setInvalidDecl();
+
   // FIXME: Perform analysis of set_exception call.
 
   // FIXME: Try to form 'p.return_void();' expression statement to handle
   // control flowing off the end of the coroutine.
 
   // Build implicit 'p.get_return_object()' expression and form initialization
   // of return type from it.
   ExprResult ReturnObject =
-    buildPromiseCall(*this, Fn, Loc, "get_return_object", None);
+      buildPromiseCall(*this, Fn, Loc, "get_return_object", None);
   if (ReturnObject.isInvalid())
     return FD->setInvalidDecl();
   QualType RetType = FD->getReturnType();
@@ -457,11 +609,12 @@
     return FD->setInvalidDecl();
 
   // FIXME: Perform move-initialization of parameters into frame-local copies.
-  SmallVector<Expr*, 16> ParamMoves;
+  SmallVector<Expr *, 16> ParamMoves;
 
   // Build body for the coroutine wrapper statement.
   Body = new (Context) CoroutineBodyStmt(
-      Body, PromiseStmt.get(), InitialSuspend.get(), FinalSuspend.get(),
-      /*SetException*/nullptr, /*Fallthrough*/nullptr,
-      ReturnObject.get(), ParamMoves);
+      Body, PromiseStmt.get(), InitialSuspend.get(),
+      cast_or_null<LabelStmt>(FinalSuspendWithLabel.get()),
+      /*SetException*/ nullptr, /*Fallthrough*/ nullptr, Allocation,
+      Deallocation, ReturnObject.get(), ParamMoves);
 }
Index: lib/CodeGen/CodeGenFunction.h
===================================================================
--- lib/CodeGen/CodeGenFunction.h
+++ lib/CodeGen/CodeGenFunction.h
@@ -2301,6 +2301,10 @@
   void EmitObjCAtSynchronizedStmt(const ObjCAtSynchronizedStmt &S);
   void EmitObjCAutoreleasePoolStmt(const ObjCAutoreleasePoolStmt &S);
 
+  /// Emit the body of a coroutine, i.e. a CoroutineBodyStmt.
+  void EmitCoroutineBody(const CoroutineBodyStmt &S);
+  /// Emit a co_return statement.
+  void EmitCoreturnStmt(const CoreturnStmt &S);
   RValue EmitCoroutineIntrinsic(const CallExpr *E, unsigned int IID);
 
   void EnterCXXTryStmt(const CXXTryStmt &S, bool IsFnTryBlock = false);
Index: lib/CodeGen/CGStmt.cpp
===================================================================
--- lib/CodeGen/CGStmt.cpp
+++ lib/CodeGen/CGStmt.cpp
@@ -142,9 +142,11 @@
   case Stmt::GCCAsmStmtClass:   // Intentional fall-through.
   case Stmt::MSAsmStmtClass:    EmitAsmStmt(cast<AsmStmt>(*S));           break;
   case Stmt::CoroutineBodyStmtClass:
-  case Stmt::CoreturnStmtClass:
-    CGM.ErrorUnsupported(S, "coroutine");
+    EmitCoroutineBody(cast<CoroutineBodyStmt>(*S));
     break;
+  case Stmt::CoreturnStmtClass:
+    EmitCoreturnStmt(cast<CoreturnStmt>(*S));
+    break;
   case Stmt::CapturedStmtClass: {
     const CapturedStmt *CS = cast<CapturedStmt>(S);
     EmitCapturedStmt(*CS, CS->getCapturedRegionKind());
Index: lib/CodeGen/CGCoroutine.cpp
===================================================================
--- lib/CodeGen/CGCoroutine.cpp
+++ lib/CodeGen/CGCoroutine.cpp
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "CodeGenFunction.h"
+#include "clang/AST/StmtCXX.h"
 
 using namespace clang;
 using namespace CodeGen;
@@ -58,6 +59,31 @@
   return true;
 }
 
+// Emits a co_return statement. So far only its promise call is emitted.
+void CodeGenFunction::EmitCoreturnStmt(const CoreturnStmt &S) {
+  EmitStmt(S.getPromiseCall());
+  // FIXME: Jump to final suspend label.
+}
+
+// Emits a coroutine body: creates the llvm.coro.id call, then emits the frame
+// allocation followed by the labeled frame deallocation.
+void CodeGenFunction::EmitCoroutineBody(const CoroutineBodyStmt &S) {
+  llvm::Constant *Null = llvm::ConstantPointerNull::get(Builder.getInt8PtrTy());
+  // FIXME: Instead of 0, pass an equivalent of alignas(maxalign_t).
+  llvm::Value *Align = Builder.getInt32(0);
+  llvm::CallInst *Id = Builder.CreateCall(
+      CGM.getIntrinsic(llvm::Intrinsic::coro_id), {Align, Null, Null, Null});
+
+  // createCoroData fails when the user inserted __builtin_coro_id by hand;
+  // in that case do not try to emit anything.
+  if (!createCoroData(*this, CurCoro, Id, nullptr))
+    return;
+
+  EmitScalarExpr(S.getAllocate());
+  // FIXME: Emit the rest of the coroutine.
+  EmitStmt(S.getDeallocate());
+}
+
 // Emit coroutine intrinsic and patch up arguments of the token type.
 RValue CodeGenFunction::EmitCoroutineIntrinsic(const CallExpr *E,
                                                unsigned int IID) {
Index: include/clang/Basic/DiagnosticSemaKinds.td
===================================================================
--- include/clang/Basic/DiagnosticSemaKinds.td
+++ include/clang/Basic/DiagnosticSemaKinds.td
@@ -8567,10 +8567,6 @@
   "'main' cannot be a coroutine">;
 def err_coroutine_varargs : Error<
   "'%0' cannot be used in a varargs function">;
-def ext_coroutine_without_co_await_co_yield : ExtWarn<
-  "'co_return' used in a function "
-  "that uses neither 'co_await' nor 'co_yield'">,
-  InGroup<DiagGroup<"coreturn-without-coawait">>;
 def err_implied_std_coroutine_traits_not_found : Error<
   "you need to include <experimental/coroutine> before defining a coroutine">;
 def err_malformed_std_coroutine_traits : Error<
Index: include/clang/AST/StmtCXX.h
===================================================================
--- include/clang/AST/StmtCXX.h
+++ include/clang/AST/StmtCXX.h
@@ -304,23 +304,28 @@
     FinalSuspend,  ///< The final suspend statement, run after the body.
     OnException,   ///< Handler for exceptions thrown in the body.
     OnFallthrough, ///< Handler for control flow falling off the body.
+    Allocate,      ///< Coroutine frame memory allocation.
+    Deallocate,    ///< Coroutine frame memory deallocation.
     ReturnValue,   ///< Return value for thunk function.
     FirstParamMove ///< First offset for move construction of parameter copies.
   };
   Stmt *SubStmts[SubStmt::FirstParamMove];
 
   friend class ASTStmtReader;
 public:
   CoroutineBodyStmt(Stmt *Body, Stmt *Promise, Stmt *InitSuspend,
-                    Stmt *FinalSuspend, Stmt *OnException, Stmt *OnFallthrough,
+                    LabelStmt *FinalSuspend, Stmt *OnException,
+                    Stmt *OnFallthrough, Expr *Allocate, LabelStmt *Deallocate,
                     Expr *ReturnValue, ArrayRef<Expr *> ParamMoves)
       : Stmt(CoroutineBodyStmtClass) {
     SubStmts[CoroutineBodyStmt::Body] = Body;
     SubStmts[CoroutineBodyStmt::Promise] = Promise;
     SubStmts[CoroutineBodyStmt::InitSuspend] = InitSuspend;
     SubStmts[CoroutineBodyStmt::FinalSuspend] = FinalSuspend;
     SubStmts[CoroutineBodyStmt::OnException] = OnException;
     SubStmts[CoroutineBodyStmt::OnFallthrough] = OnFallthrough;
+    SubStmts[CoroutineBodyStmt::Allocate] = Allocate;
+    SubStmts[CoroutineBodyStmt::Deallocate] = Deallocate;
     SubStmts[CoroutineBodyStmt::ReturnValue] = ReturnValue;
     // FIXME: Tail-allocate space for parameter move expressions and store them.
     assert(ParamMoves.empty() && "not implemented yet");
@@ -338,13 +343,25 @@
   }
 
   Stmt *getInitSuspendStmt() const { return SubStmts[SubStmt::InitSuspend]; }
-  Stmt *getFinalSuspendStmt() const { return SubStmts[SubStmt::FinalSuspend]; }
+  LabelStmt *getFinalSuspendStmt() const {
+    return cast<LabelStmt>(SubStmts[SubStmt::FinalSuspend]);
+  }
 
   Stmt *getExceptionHandler() const { return SubStmts[SubStmt::OnException]; }
   Stmt *getFallthroughHandler() const {
     return SubStmts[SubStmt::OnFallthrough];
   }
 
+  /// The coroutine frame allocation expression. Null if it was not built
+  /// (Sema skips it when the promise type is dependent).
+  Expr *getAllocate() const {
+    return cast_or_null<Expr>(SubStmts[SubStmt::Allocate]);
+  }
+  /// The labeled coroutine frame deallocation statement; null if not built.
+  LabelStmt *getDeallocate() const {
+    return cast_or_null<LabelStmt>(SubStmts[SubStmt::Deallocate]);
+  }
+
   Expr *getReturnValueInit() const {
     return cast<Expr>(SubStmts[SubStmt::ReturnValue]);
   }
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
http://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to