modocache updated this revision to Diff 134304.
modocache added a comment.

Thanks for the review, @GorNishanov, and for pointing out that part of the 
standard! I added a reference to the implicit object parameter, as well as some 
tests for that behavior. And thanks for pointing out the flaw in 
https://reviews.llvm.org/D41820 -- I'll submit a separate diff to forward the 
implicit object parameter to promise type constructors.


Repository:
  rC Clang

https://reviews.llvm.org/D42606

Files:
  lib/Sema/SemaCoroutine.cpp
  test/CodeGenCoroutines/coro-alloc.cpp
  test/CodeGenCoroutines/coro-gro-nrvo.cpp
  test/CodeGenCoroutines/coro-params.cpp
  test/SemaCXX/coroutines.cpp

Index: test/SemaCXX/coroutines.cpp
===================================================================
--- test/SemaCXX/coroutines.cpp
+++ test/SemaCXX/coroutines.cpp
@@ -781,6 +781,102 @@
 }
 template coro<good_promise_13> dependent_uses_nothrow_new(good_promise_13);
 
+struct good_promise_custom_new_operator {
+  coro<good_promise_custom_new_operator> get_return_object();
+  suspend_always initial_suspend();
+  suspend_always final_suspend();
+  void return_void();
+  void unhandled_exception();
+  void *operator new(unsigned long, double, float, int);
+};
+
+coro<good_promise_custom_new_operator>
+good_coroutine_calls_custom_new_operator(double, float, int) {
+  co_return;
+}
+
+struct coroutine_nonstatic_member_struct;
+
+struct good_promise_nonstatic_member_custom_new_operator {
+  coro<good_promise_nonstatic_member_custom_new_operator> get_return_object();
+  suspend_always initial_suspend();
+  suspend_always final_suspend();
+  void return_void();
+  void unhandled_exception();
+  void *operator new(unsigned long, coroutine_nonstatic_member_struct &, double);
+};
+
+struct bad_promise_nonstatic_member_mismatched_custom_new_operator {
+  coro<bad_promise_nonstatic_member_mismatched_custom_new_operator> get_return_object();
+  suspend_always initial_suspend();
+  suspend_always final_suspend();
+  void return_void();
+  void unhandled_exception();
+  // expected-note@+1 {{candidate function not viable: requires 2 arguments, but 1 was provided}}
+  void *operator new(unsigned long, double);
+};
+
+struct coroutine_nonstatic_member_struct {
+  coro<good_promise_nonstatic_member_custom_new_operator>
+  good_coroutine_calls_nonstatic_member_custom_new_operator(double) {
+    co_return;
+  }
+
+  coro<bad_promise_nonstatic_member_mismatched_custom_new_operator>
+  bad_coroutine_calls_nonstatic_member_mistmatched_custom_new_operator(double) {
+    // expected-error@-1 {{no matching function for call to 'operator new'}}
+    co_return;
+  }
+};
+
+struct bad_promise_mismatched_custom_new_operator {
+  coro<bad_promise_mismatched_custom_new_operator> get_return_object();
+  suspend_always initial_suspend();
+  suspend_always final_suspend();
+  void return_void();
+  void unhandled_exception();
+  // expected-note@+1 {{candidate function not viable: requires 4 arguments, but 1 was provided}}
+  void *operator new(unsigned long, double, float, int);
+};
+
+coro<bad_promise_mismatched_custom_new_operator>
+bad_coroutine_calls_mismatched_custom_new_operator(double) {
+  // expected-error@-1 {{no matching function for call to 'operator new'}}
+  co_return;
+}
+
+struct bad_promise_throwing_custom_new_operator {
+  static coro<bad_promise_throwing_custom_new_operator> get_return_object_on_allocation_failure();
+  coro<bad_promise_throwing_custom_new_operator> get_return_object();
+  suspend_always initial_suspend();
+  suspend_always final_suspend();
+  void return_void();
+  void unhandled_exception();
+  // expected-error@+1 {{'operator new' is required to have a non-throwing noexcept specification when the promise type declares 'get_return_object_on_allocation_failure()'}}
+  void *operator new(unsigned long, double, float, int);
+};
+
+coro<bad_promise_throwing_custom_new_operator>
+bad_coroutine_calls_throwing_custom_new_operator(double, float, int) {
+  // expected-note@-1 {{call to 'operator new' implicitly required by coroutine function here}}
+  co_return;
+}
+
+struct good_promise_noexcept_custom_new_operator {
+  static coro<good_promise_noexcept_custom_new_operator> get_return_object_on_allocation_failure();
+  coro<good_promise_noexcept_custom_new_operator> get_return_object();
+  suspend_always initial_suspend();
+  suspend_always final_suspend();
+  void return_void();
+  void unhandled_exception();
+  void *operator new(unsigned long, double, float, int) noexcept;
+};
+
+coro<good_promise_noexcept_custom_new_operator>
+good_coroutine_calls_noexcept_custom_new_operator(double, float, int) {
+  co_return;
+}
+
 struct mismatch_gro_type_tag1 {};
 template<>
 struct std::experimental::coroutine_traits<int, mismatch_gro_type_tag1> {
Index: test/CodeGenCoroutines/coro-params.cpp
===================================================================
--- test/CodeGenCoroutines/coro-params.cpp
+++ test/CodeGenCoroutines/coro-params.cpp
@@ -69,12 +69,12 @@
   // CHECK: store i32 %val, i32* %[[ValAddr:.+]]
 
   // CHECK: call i8* @llvm.coro.begin(
-  // CHECK-NEXT: call void @_ZN8MoveOnlyC1EOS_(%struct.MoveOnly* %[[MoCopy]], %struct.MoveOnly* dereferenceable(4) %[[MoParam]])
+  // CHECK: call void @_ZN8MoveOnlyC1EOS_(%struct.MoveOnly* %[[MoCopy]], %struct.MoveOnly* dereferenceable(4) %[[MoParam]])
   // CHECK-NEXT: call void @_ZN11MoveAndCopyC1EOS_(%struct.MoveAndCopy* %[[McCopy]], %struct.MoveAndCopy* dereferenceable(4) %[[McParam]]) #
   // CHECK-NEXT: invoke void @_ZNSt12experimental16coroutine_traitsIJvi8MoveOnly11MoveAndCopyEE12promise_typeC1Ev(
 
   // CHECK: call void @_ZN14suspend_always12await_resumeEv(
-  // CHECK: %[[IntParam:.+]] = load i32, i32* %val.addr
+  // CHECK: %[[IntParam:.+]] = load i32, i32* %val1
   // CHECK: %[[MoGep:.+]] = getelementptr inbounds %struct.MoveOnly, %struct.MoveOnly* %[[MoCopy]], i32 0, i32 0
   // CHECK: %[[MoVal:.+]] = load i32, i32* %[[MoGep]]
   // CHECK: %[[McGep:.+]] =  getelementptr inbounds %struct.MoveAndCopy, %struct.MoveAndCopy* %[[McCopy]], i32 0, i32 0
@@ -150,9 +150,9 @@
 
 // CHECK-LABEL: void @_Z38coroutine_matching_promise_constructor28promise_matching_constructorifd(i32, float, double)
 void coroutine_matching_promise_constructor(promise_matching_constructor, int, float, double) {
-  // CHECK: %[[INT:.+]] = load i32, i32* %.addr, align 4
-  // CHECK: %[[FLOAT:.+]] = load float, float* %.addr1, align 4
-  // CHECK: %[[DOUBLE:.+]] = load double, double* %.addr2, align 8
+  // CHECK: %[[INT:.+]] = load i32, i32* %5, align 4
+  // CHECK: %[[FLOAT:.+]] = load float, float* %6, align 4
+  // CHECK: %[[DOUBLE:.+]] = load double, double* %7, align 8
   // CHECK: invoke void @_ZNSt12experimental16coroutine_traitsIJv28promise_matching_constructorifdEE12promise_typeC1ES1_ifd(%"struct.std::experimental::coroutine_traits<void, promise_matching_constructor, int, float, double>::promise_type"* %__promise, i32 %[[INT]], float %[[FLOAT]], double %[[DOUBLE]])
   co_return;
 }
Index: test/CodeGenCoroutines/coro-gro-nrvo.cpp
===================================================================
--- test/CodeGenCoroutines/coro-gro-nrvo.cpp
+++ test/CodeGenCoroutines/coro-gro-nrvo.cpp
@@ -41,7 +41,7 @@
 
 // CHECK: {{.*}}[[CoroInit]]:
 // CHECK: store i1 false, i1* %gro.active
-// CHECK-NEXT: call void @{{.*get_return_objectEv}}(%struct.coro* sret %agg.result
+// CHECK: call void @{{.*get_return_objectEv}}(%struct.coro* sret %agg.result
 // CHECK-NEXT: store i1 true, i1* %gro.active
   co_return;
 }
@@ -78,7 +78,7 @@
 
 // CHECK: {{.*}}[[InitOnSuccess]]:
 // CHECK: store i1 false, i1* %gro.active
-// CHECK-NEXT: call void @{{.*get_return_objectEv}}(%struct.coro_two* sret %agg.result
+// CHECK: call void @{{.*get_return_objectEv}}(%struct.coro_two* sret %agg.result
 // CHECK-NEXT: store i1 true, i1* %gro.active
 
 // CHECK: [[RetLabel]]:
Index: test/CodeGenCoroutines/coro-alloc.cpp
===================================================================
--- test/CodeGenCoroutines/coro-alloc.cpp
+++ test/CodeGenCoroutines/coro-alloc.cpp
@@ -106,6 +106,34 @@
   co_return;
 }
 
+struct promise_matching_placement_new_tag {};
+
+template<>
+struct std::experimental::coroutine_traits<void, promise_matching_placement_new_tag, int, float, double> {
+  struct promise_type {
+    void *operator new(unsigned long, promise_matching_placement_new_tag,
+                       int, float, double);
+    void get_return_object() {}
+    suspend_always initial_suspend() { return {}; }
+    suspend_always final_suspend() { return {}; }
+    void return_void() {}
+  };
+};
+
+// CHECK-LABEL: f1a(
+extern "C" void f1a(promise_matching_placement_new_tag, int x, float y , double z) {
+  // CHECK: store i32 %x, i32* %x.addr, align 4
+  // CHECK: store float %y, float* %y.addr, align 4
+  // CHECK: store double %z, double* %z.addr, align 8
+  // CHECK: %[[ID:.+]] = call token @llvm.coro.id(i32 16
+  // CHECK: %[[SIZE:.+]] = call i64 @llvm.coro.size.i64()
+  // CHECK: %[[INT:.+]] = load i32, i32* %x.addr, align 4
+  // CHECK: %[[FLOAT:.+]] = load float, float* %y.addr, align 4
+  // CHECK: %[[DOUBLE:.+]] = load double, double* %z.addr, align 8
+  // CHECK: call i8* @_ZNSt12experimental16coroutine_traitsIJv34promise_matching_placement_new_tagifdEE12promise_typenwEmS1_ifd(i64 %[[SIZE]], i32 %[[INT]], float %[[FLOAT]], double %[[DOUBLE]])
+  co_return;
+}
+
 struct promise_delete_tag {};
 
 template<>
Index: lib/Sema/SemaCoroutine.cpp
===================================================================
--- lib/Sema/SemaCoroutine.cpp
+++ lib/Sema/SemaCoroutine.cpp
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "CoroutineStmtBuilder.h"
+#include "clang/AST/ASTLambda.h"
 #include "clang/AST/Decl.h"
 #include "clang/AST/ExprCXX.h"
 #include "clang/AST/StmtCXX.h"
@@ -506,24 +507,15 @@
 
     auto RefExpr = ExprEmpty();
     auto Move = Moves.find(PD);
-    if (Move != Moves.end()) {
-      // If a reference to the function parameter exists in the coroutine
-      // frame, use that reference.
-      auto *MoveDecl =
-          cast<VarDecl>(cast<DeclStmt>(Move->second)->getSingleDecl());
-      RefExpr = BuildDeclRefExpr(MoveDecl, MoveDecl->getType(),
-                                 ExprValueKind::VK_LValue, FD->getLocation());
-    } else {
-      // If the function parameter doesn't exist in the coroutine frame, it
-      // must be a scalar value. Use it directly.
-      assert(!PD->getType()->getAsCXXRecordDecl() &&
-             "Non-scalar types should have been moved and inserted into the "
-             "parameter moves map");
-      RefExpr =
-          BuildDeclRefExpr(PD, PD->getOriginalType().getNonReferenceType(),
-                           ExprValueKind::VK_LValue, FD->getLocation());
-    }
-
+    assert(Move != Moves.end() &&
+           "Coroutine function parameter not inserted into move map");
+    // The move map holds a variable for every parameter (asserted above);
+    // reference that variable rather than the parameter itself.
+    auto *MoveDecl =
+        cast<VarDecl>(cast<DeclStmt>(Move->second)->getSingleDecl());
+    RefExpr =
+        BuildDeclRefExpr(MoveDecl, MoveDecl->getType().getNonReferenceType(),
+                         ExprValueKind::VK_LValue, FD->getLocation());
     if (RefExpr.isInvalid())
       return nullptr;
     CtorArgExprs.push_back(RefExpr.get());
@@ -1050,18 +1042,75 @@
 
   const bool RequiresNoThrowAlloc = ReturnStmtOnAllocFailure != nullptr;
 
-  // FIXME: Add support for stateful allocators.
+  // [dcl.fct.def.coroutine]/7
+  // Lookup allocation functions using a parameter list composed of the
+  // requested size of the coroutine state being allocated, followed by
+  // the coroutine function's arguments. If a matching allocation function
+  // exists, use it. Otherwise, use an allocation function that just takes
+  // the requested size.
 
   FunctionDecl *OperatorNew = nullptr;
   FunctionDecl *OperatorDelete = nullptr;
   FunctionDecl *UnusedResult = nullptr;
   bool PassAlignment = false;
   SmallVector<Expr *, 1> PlacementArgs;
 
+  // [dcl.fct.def.coroutine]/7
+  // "The allocation function's name is looked up in the scope of P.
+  // [...] If the lookup finds an allocation function in the scope of P,
+  // overload resolution is performed on a function call created by assembling
+  // an argument list. The first argument is the amount of space requested,
+  // and has type std::size_t. The lvalues p1 ... pn are the succeeding
+  // arguments."
+  //
+  // ...where "p1 ... pn" are defined earlier as:
+  //
+  // [dcl.fct.def.coroutine]/3
+  // "For a coroutine f that is a non-static member function, let P1 denote the
+  // type of the implicit object parameter (13.3.1) and P2 ... Pn be the types
+  // of the function parameters; otherwise let P1 ... Pn be the types of the
+  // function parameters. Let p1 ... pn be lvalues denoting those objects."
+  if (auto *MD = dyn_cast<CXXMethodDecl>(&FD)) {
+    if (MD->isInstance() && !isLambdaCallOperator(MD)) {
+      ExprResult ThisExpr = S.ActOnCXXThis(Loc);
+      if (ThisExpr.isInvalid())
+        return false;
+      ThisExpr = S.CreateBuiltinUnaryOp(Loc, UO_Deref, ThisExpr.get());
+      if (ThisExpr.isInvalid())
+        return false;
+      PlacementArgs.push_back(ThisExpr.get());
+    }
+  }
+  for (auto *PD : FD.parameters()) {
+    if (PD->getType()->isDependentType())
+      continue;
+
+    // Build a reference to the parameter.
+    auto PDLoc = PD->getLocation();
+    ExprResult PDRefExpr =
+        S.BuildDeclRefExpr(PD, PD->getOriginalType().getNonReferenceType(),
+                           ExprValueKind::VK_LValue, PDLoc);
+    if (PDRefExpr.isInvalid())
+      return false;
+
+    PlacementArgs.push_back(PDRefExpr.get());
+  }
   S.FindAllocationFunctions(Loc, SourceRange(),
                             /*UseGlobal*/ false, PromiseType,
                             /*isArray*/ false, PassAlignment, PlacementArgs,
-                            OperatorNew, UnusedResult);
+                            OperatorNew, UnusedResult, /*Diagnose*/ false);
+
+  // [dcl.fct.def.coroutine]/7
+  // "If no matching function is found, overload resolution is performed again
+  // on a function call created by passing just the amount of space required as
+  // an argument of type std::size_t."
+  if (!OperatorNew && !PlacementArgs.empty()) {
+    PlacementArgs.clear();
+    S.FindAllocationFunctions(Loc, SourceRange(),
+                              /*UseGlobal*/ false, PromiseType,
+                              /*isArray*/ false, PassAlignment,
+                              PlacementArgs, OperatorNew, UnusedResult);
+  }
 
   bool IsGlobalOverload =
       OperatorNew && !isa<CXXRecordDecl>(OperatorNew->getDeclContext());
@@ -1080,7 +1129,8 @@
                               OperatorNew, UnusedResult);
   }
 
-  assert(OperatorNew && "expected definition of operator new to be found");
+  if (!OperatorNew)
+    return false;
 
   if (RequiresNoThrowAlloc) {
     const auto *FT = OperatorNew->getType()->getAs<FunctionProtoType>();
@@ -1386,25 +1436,28 @@
     if (PD->getType()->isDependentType())
       continue;
 
-    // No need to copy scalars, LLVM will take care of them.
-    if (PD->getType()->getAsCXXRecordDecl()) {
-      ExprResult PDRefExpr = BuildDeclRefExpr(
-          PD, PD->getType(), ExprValueKind::VK_LValue, Loc); // FIXME: scope?
-      if (PDRefExpr.isInvalid())
-        return false;
+    ExprResult PDRefExpr =
+        BuildDeclRefExpr(PD, PD->getType().getNonReferenceType(),
+                         ExprValueKind::VK_LValue, Loc); // FIXME: scope?
+    if (PDRefExpr.isInvalid())
+      return false;
 
-      Expr *CExpr = castForMoving(*this, PDRefExpr.get());
+    Expr *CExpr = nullptr;
+    if (PD->getType()->getAsCXXRecordDecl() ||
+        PD->getType()->isRValueReferenceType())
+      CExpr = castForMoving(*this, PDRefExpr.get());
+    else
+      CExpr = PDRefExpr.get();
 
-      auto D = buildVarDecl(*this, Loc, PD->getType(), PD->getIdentifier());
-      AddInitializerToDecl(D, CExpr, /*DirectInit=*/true);
+    auto D = buildVarDecl(*this, Loc, PD->getType(), PD->getIdentifier());
+    AddInitializerToDecl(D, CExpr, /*DirectInit=*/true);
 
-      // Convert decl to a statement.
-      StmtResult Stmt = ActOnDeclStmt(ConvertDeclToDeclGroup(D), Loc, Loc);
-      if (Stmt.isInvalid())
-        return false;
+    // Convert decl to a statement.
+    StmtResult Stmt = ActOnDeclStmt(ConvertDeclToDeclGroup(D), Loc, Loc);
+    if (Stmt.isInvalid())
+      return false;
 
-      ScopeInfo->CoroutineParameterMoves.insert(std::make_pair(PD, Stmt.get()));
-    }
+    ScopeInfo->CoroutineParameterMoves.insert(std::make_pair(PD, Stmt.get()));
   }
   return true;
 }
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
http://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to