This revision was automatically updated to reflect the committed changes.
GorNishanov marked an inline comment as done.
Closed by commit rL298891: [coroutines] Handle get_return_object_on_allocation_failure (authored by GorNishanov).

Changed prior to commit:
  https://reviews.llvm.org/D31399?vs=93177&id=93196#toc

Repository:
  rL LLVM

https://reviews.llvm.org/D31399

Files:
  cfe/trunk/include/clang/AST/StmtCXX.h
  cfe/trunk/include/clang/Basic/DiagnosticSemaKinds.td
  cfe/trunk/lib/AST/StmtCXX.cpp
  cfe/trunk/lib/CodeGen/CGCoroutine.cpp
  cfe/trunk/lib/Sema/SemaCoroutine.cpp
  cfe/trunk/test/CodeGenCoroutines/coro-alloc.cpp
  cfe/trunk/test/SemaCXX/coroutines.cpp

Index: cfe/trunk/test/SemaCXX/coroutines.cpp
===================================================================
--- cfe/trunk/test/SemaCXX/coroutines.cpp
+++ cfe/trunk/test/SemaCXX/coroutines.cpp
@@ -634,3 +634,21 @@
   //expected-note@-1 {{call to 'initial_suspend' implicitly required by the initial suspend point}}
   co_return; //expected-note {{function is a coroutine due to use of 'co_return' here}}
 }
+
+struct promise_on_alloc_failure_tag {};
+
+template<>
+struct std::experimental::coroutine_traits<int, promise_on_alloc_failure_tag> {
+  struct promise_type {
+    int get_return_object() {}
+    suspend_always initial_suspend() { return {}; }
+    suspend_always final_suspend() { return {}; }
+    void return_void() {}
+    int get_return_object_on_allocation_failure(); // expected-error{{'promise_type': 'get_return_object_on_allocation_failure()' must be a static member function}}
+    void unhandled_exception();
+  };
+};
+
+extern "C" int f(promise_on_alloc_failure_tag) {
+  co_return; //expected-note {{function is a coroutine due to use of 'co_return' here}}
+}
Index: cfe/trunk/test/CodeGenCoroutines/coro-alloc.cpp
===================================================================
--- cfe/trunk/test/CodeGenCoroutines/coro-alloc.cpp
+++ cfe/trunk/test/CodeGenCoroutines/coro-alloc.cpp
@@ -40,7 +40,7 @@
   };
 };
 
-// CHECK-LABEL: f0( 
+// CHECK-LABEL: f0(
 extern "C" void f0(global_new_delete_tag) {
   // CHECK: %[[ID:.+]] = call token @llvm.coro.id(i32 16
   // CHECK: %[[SIZE:.+]] = call i64 @llvm.coro.size.i64()
@@ -65,7 +65,7 @@
   };
 };
 
-// CHECK-LABEL: f1( 
+// CHECK-LABEL: f1(
 extern "C" void f1(promise_new_tag ) {
   // CHECK: %[[ID:.+]] = call token @llvm.coro.id(i32 16
   // CHECK: %[[SIZE:.+]] = call i64 @llvm.coro.size.i64()
@@ -90,7 +90,7 @@
   };
 };
 
-// CHECK-LABEL: f2( 
+// CHECK-LABEL: f2(
 extern "C" void f2(promise_delete_tag) {
   // CHECK: %[[ID:.+]] = call token @llvm.coro.id(i32 16
   // CHECK: %[[SIZE:.+]] = call i64 @llvm.coro.size.i64()
@@ -127,3 +127,30 @@
   // CHECK: call void @_ZNSt12experimental16coroutine_traitsIJv24promise_sized_delete_tagEE12promise_typedlEPvm(i8* %[[MEM]], i64 %[[SIZE2]])
   co_return;
 }
+
+struct promise_on_alloc_failure_tag {};
+
+template<>
+struct std::experimental::coroutine_traits<int, promise_on_alloc_failure_tag> {
+  struct promise_type {
+    int get_return_object() {}
+    suspend_always initial_suspend() { return {}; }
+    suspend_always final_suspend() { return {}; }
+    void return_void() {}
+    static int get_return_object_on_allocation_failure() { return -1; }
+  };
+};
+
+// CHECK-LABEL: f4(
+extern "C" int f4(promise_on_alloc_failure_tag) {
+  // CHECK: %[[ID:.+]] = call token @llvm.coro.id(i32 16
+  // CHECK: %[[SIZE:.+]] = call i64 @llvm.coro.size.i64()
+  // CHECK: %[[MEM:.+]] = call i8* @_Znwm(i64 %[[SIZE]])
+  // CHECK: %[[OK:.+]] = icmp ne i8* %[[MEM]], null
+  // CHECK: br i1 %[[OK]], label %[[OKBB:.+]], label %[[ERRBB:.+]]
+
+  // CHECK: [[ERRBB]]:
+  // CHECK: %[[RETVAL:.+]] = call i32 @_ZNSt12experimental16coroutine_traitsIJi28promise_on_alloc_failure_tagEE12promise_type39get_return_object_on_allocation_failureEv(
+  // CHECK: ret i32 %[[RETVAL]]
+  co_return;
+}
Index: cfe/trunk/lib/CodeGen/CGCoroutine.cpp
===================================================================
--- cfe/trunk/lib/CodeGen/CGCoroutine.cpp
+++ cfe/trunk/lib/CodeGen/CGCoroutine.cpp
@@ -229,7 +229,24 @@
   createCoroData(*this, CurCoro, CoroId);
   CurCoro.Data->SuspendBB = RetBB;
 
-  EmitScalarExpr(S.getAllocate());
+  auto *AllocateCall = EmitScalarExpr(S.getAllocate());
+
+  // Handle allocation failure if 'ReturnStmtOnAllocFailure' was provided.
+  if (auto *RetOnAllocFailure = S.getReturnStmtOnAllocFailure()) {
+    auto *RetOnFailureBB = createBasicBlock("coro.ret.on.failure");
+    auto *InitBB = createBasicBlock("coro.init");
+
+    // See if allocation was successful.
+    auto *NullPtr = llvm::ConstantPointerNull::get(Int8PtrTy);
+    auto *Cond = Builder.CreateICmpNE(AllocateCall, NullPtr);
+    Builder.CreateCondBr(Cond, InitBB, RetOnFailureBB);
+
+    // If not, return OnAllocFailure object.
+    EmitBlock(RetOnFailureBB);
+    EmitStmt(RetOnAllocFailure);
+
+    EmitBlock(InitBB);
+  }
 
   // FIXME: Setup cleanup scopes.
 
Index: cfe/trunk/lib/AST/StmtCXX.cpp
===================================================================
--- cfe/trunk/lib/AST/StmtCXX.cpp
+++ cfe/trunk/lib/AST/StmtCXX.cpp
@@ -108,6 +108,8 @@
   SubStmts[CoroutineBodyStmt::Allocate] = Args.Allocate;
   SubStmts[CoroutineBodyStmt::Deallocate] = Args.Deallocate;
   SubStmts[CoroutineBodyStmt::ReturnValue] = Args.ReturnValue;
+  SubStmts[CoroutineBodyStmt::ReturnStmtOnAllocFailure] =
+      Args.ReturnStmtOnAllocFailure;
   std::copy(Args.ParamMoves.begin(), Args.ParamMoves.end(),
             const_cast<Stmt **>(getParamMoves().data()));
 }
\ No newline at end of file
Index: cfe/trunk/lib/Sema/SemaCoroutine.cpp
===================================================================
--- cfe/trunk/lib/Sema/SemaCoroutine.cpp
+++ cfe/trunk/lib/Sema/SemaCoroutine.cpp
@@ -708,8 +708,8 @@
     }
     this->IsValid = makePromiseStmt() && makeInitialAndFinalSuspend() &&
                     makeOnException() && makeOnFallthrough() &&
-                    makeNewAndDeleteExpr() && makeReturnObject() &&
-                    makeParamMoves();
+                    makeReturnOnAllocFailure() && makeNewAndDeleteExpr() &&
+                    makeReturnObject() && makeParamMoves();
   }
 
   bool isInvalid() const { return !this->IsValid; }
@@ -720,6 +720,7 @@
   bool makeOnFallthrough();
   bool makeOnException();
   bool makeReturnObject();
+  bool makeReturnOnAllocFailure();
   bool makeParamMoves();
 };
 }
@@ -777,6 +778,66 @@
   return true;
 }
 
+static bool diagReturnOnAllocFailure(Sema &S, Expr *E,
+                                     CXXRecordDecl *PromiseRecordDecl,
+                                     FunctionScopeInfo &Fn) {
+  auto Loc = E->getExprLoc();
+  if (auto *DeclRef = dyn_cast_or_null<DeclRefExpr>(E)) {
+    auto *Decl = DeclRef->getDecl();
+    if (CXXMethodDecl *Method = dyn_cast_or_null<CXXMethodDecl>(Decl)) {
+      if (Method->isStatic())
+        return true;
+      else
+        Loc = Decl->getLocation();
+    }
+  }
+
+  S.Diag(
+      Loc,
+      diag::err_coroutine_promise_get_return_object_on_allocation_failure)
+      << PromiseRecordDecl;
+  S.Diag(Fn.FirstCoroutineStmtLoc, diag::note_declared_coroutine_here)
+      << Fn.getFirstCoroutineStmtKeyword();
+  return false;
+}
+
+bool SubStmtBuilder::makeReturnOnAllocFailure() {
+  if (!PromiseRecordDecl) return true;
+
+  // [dcl.fct.def.coroutine]/8
+  // The unqualified-id get_return_object_on_allocation_failure is looked up in
+  // the scope of class P by class member access lookup (3.4.5). ...
+  // If an allocation function returns nullptr, ... the coroutine return value
+  // is obtained by a call to ... get_return_object_on_allocation_failure().
+
+  DeclarationName DN =
+      S.PP.getIdentifierInfo("get_return_object_on_allocation_failure");
+  LookupResult Found(S, DN, Loc, Sema::LookupMemberName);
+  // Suppress diagnostics when a private member is selected. The same warnings
+  // will be produced again when building the call.
+  Found.suppressDiagnostics();
+  if (!S.LookupQualifiedName(Found, PromiseRecordDecl)) return true;
+
+  CXXScopeSpec SS;
+  ExprResult DeclNameExpr =
+      S.BuildDeclarationNameExpr(SS, Found, /*NeedsADL=*/false);
+  if (DeclNameExpr.isInvalid()) return false;
+
+  if (!diagReturnOnAllocFailure(S, DeclNameExpr.get(), PromiseRecordDecl, Fn))
+    return false;
+
+  ExprResult ReturnObjectOnAllocationFailure =
+      S.ActOnCallExpr(nullptr, DeclNameExpr.get(), Loc, {}, Loc);
+  if (ReturnObjectOnAllocationFailure.isInvalid()) return false;
+
+  StmtResult ReturnStmt = S.ActOnReturnStmt(
+      Loc, ReturnObjectOnAllocationFailure.get(), S.getCurScope());
+  if (ReturnStmt.isInvalid()) return false;
+
+  this->ReturnStmtOnAllocFailure = ReturnStmt.get();
+  return true;
+}
+
 bool SubStmtBuilder::makeNewAndDeleteExpr() {
   // Form and check allocation and deallocation calls.
   QualType PromiseType = Fn.CoroutinePromise->getType();
@@ -786,7 +847,8 @@
   if (S.RequireCompleteType(Loc, PromiseType, diag::err_incomplete_type))
     return false;
 
-  // FIXME: Add support for get_return_object_on_allocation failure.
+  // FIXME: Add nothrow_t placement arg for global alloc
+  //        if ReturnStmtOnAllocFailure != nullptr.
   // FIXME: Add support for stateful allocators.
 
   FunctionDecl *OperatorNew = nullptr;
Index: cfe/trunk/include/clang/AST/StmtCXX.h
===================================================================
--- cfe/trunk/include/clang/AST/StmtCXX.h
+++ cfe/trunk/include/clang/AST/StmtCXX.h
@@ -309,6 +309,7 @@
     Allocate,      ///< Coroutine frame memory allocation.
     Deallocate,    ///< Coroutine frame memory deallocation.
     ReturnValue,   ///< Return value for thunk function.
+    ReturnStmtOnAllocFailure, ///< Return statement if allocation failed.
     FirstParamMove ///< First offset for move construction of parameter copies.
   };
   unsigned NumParams;
@@ -332,6 +333,7 @@
     Expr *Allocate = nullptr;
     Expr *Deallocate = nullptr;
     Stmt *ReturnValue = nullptr;
+    Stmt *ReturnStmtOnAllocFailure = nullptr;
     ArrayRef<Stmt *> ParamMoves;
   };
 
@@ -379,6 +381,9 @@
   Expr *getReturnValueInit() const {
     return cast_or_null<Expr>(getStoredStmts()[SubStmt::ReturnValue]);
   }
+  Stmt *getReturnStmtOnAllocFailure() const {
+    return getStoredStmts()[SubStmt::ReturnStmtOnAllocFailure];
+  }
   ArrayRef<Stmt const *> getParamMoves() const {
     return {getStoredStmts() + SubStmt::FirstParamMove, NumParams};
   }
Index: cfe/trunk/include/clang/Basic/DiagnosticSemaKinds.td
===================================================================
--- cfe/trunk/include/clang/Basic/DiagnosticSemaKinds.td
+++ cfe/trunk/include/clang/Basic/DiagnosticSemaKinds.td
@@ -8878,6 +8878,8 @@
 def warn_coroutine_promise_unhandled_exception_required_with_exceptions : Warning<
   "%0 is required to declare the member 'unhandled_exception()' when exceptions are enabled">,
   InGroup<Coroutine>;
+def err_coroutine_promise_get_return_object_on_allocation_failure : Error<
+  "%0: 'get_return_object_on_allocation_failure()' must be a static member function">;
 }
 
 let CategoryName = "Documentation Issue" in {
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
http://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to