EricWF created this revision.
EricWF added reviewers: rsmith, GorNishanov.
EricWF added a subscriber: cfe-commits.
Herald added a subscriber: mehdi_amini.

This patch passes a `coroutine_handle` object to `await_suspend` calls.

It builds the `coroutine_handle` using
`coroutine_handle<PromiseType>::from_address(__builtin_coro_frame())`.


https://reviews.llvm.org/D26316

Files:
  include/clang/Basic/DiagnosticSemaKinds.td
  lib/Sema/SemaCoroutine.cpp
  test/CodeGenCoroutines/coro-alloc.cpp
  test/SemaCXX/coroutines.cpp

Index: test/SemaCXX/coroutines.cpp
===================================================================
--- test/SemaCXX/coroutines.cpp
+++ test/SemaCXX/coroutines.cpp
@@ -16,33 +16,26 @@
   // expected-error@-1 {{use of undeclared identifier 'a'}}
 }
 
-
-struct awaitable {
-  bool await_ready();
-  void await_suspend(); // FIXME: coroutine_handle
-  void await_resume();
-} a;
-
-struct suspend_always {
-  bool await_ready() { return false; }
-  void await_suspend() {}
-  void await_resume() {}
-};
-
-struct suspend_never {
-  bool await_ready() { return true; }
-  void await_suspend() {}
-  void await_resume() {}
-};
-
 void no_coroutine_traits() {
-  co_await a; // expected-error {{need to include <experimental/coroutine>}}
+  co_await 4; // expected-error {{need to include <experimental/coroutine>}}
 }
 
 namespace std {
 namespace experimental {
 template <typename... T>
 struct coroutine_traits; // expected-note {{declared here}}
+
+template <class PromiseType = void>
+struct coroutine_handle {
+  static coroutine_handle from_address(void *);
+};
+
+template <>
+struct coroutine_handle<void> {
+  template <class PromiseType>
+  coroutine_handle(coroutine_handle<PromiseType>);
+  static coroutine_handle from_address(void *);
+};
 }
 }
 
@@ -52,6 +45,24 @@
   using promise_type = Promise;
 };
 
+struct awaitable {
+  bool await_ready();
+  void await_suspend(std::experimental::coroutine_handle<>);
+  void await_resume();
+} a;
+
+struct suspend_always {
+  bool await_ready() { return false; }
+  void await_suspend(std::experimental::coroutine_handle<>) {}
+  void await_resume() {}
+};
+
+struct suspend_never {
+  bool await_ready() { return true; }
+  void await_suspend(std::experimental::coroutine_handle<>) {}
+  void await_resume() {}
+};
+
 void no_specialization() {
   co_await a; // expected-error {{implicit instantiation of undefined template 'std::experimental::coroutine_traits<void>'}}
 }
@@ -86,13 +97,6 @@
 struct std::experimental::coroutine_traits<void, void_tag, T...>
 { using promise_type = promise_void; };
 
-namespace std {
-namespace experimental {
-template <typename Promise = void>
-struct coroutine_handle;
-}
-}
-
 // FIXME: This diagnostic is terrible.
 void undefined_promise() { // expected-error {{this function cannot be a coroutine: 'experimental::coroutine_traits<void>::promise_type' (aka 'promise') is an incomplete type}}
   co_await a;
Index: test/CodeGenCoroutines/coro-alloc.cpp
===================================================================
--- test/CodeGenCoroutines/coro-alloc.cpp
+++ test/CodeGenCoroutines/coro-alloc.cpp
@@ -4,12 +4,26 @@
 namespace experimental {
 template <typename... T>
 struct coroutine_traits; // expected-note {{declared here}}
+
+template <class Promise = void>
+struct coroutine_handle {
+  coroutine_handle() = default;
+  static coroutine_handle from_address(void *) { return {}; }
+};
+
+template <>
+struct coroutine_handle<void> {
+  static coroutine_handle from_address(void *) { return {}; }
+  coroutine_handle() = default;
+  template <class PromiseType>
+  coroutine_handle(coroutine_handle<PromiseType>) {}
+};
 }
 }
 
 struct suspend_always {
   bool await_ready() { return false; }
-  void await_suspend() {}
+  void await_suspend(std::experimental::coroutine_handle<>) {}
   void await_resume() {}
 };
 
Index: lib/Sema/SemaCoroutine.cpp
===================================================================
--- lib/Sema/SemaCoroutine.cpp
+++ lib/Sema/SemaCoroutine.cpp
@@ -116,6 +116,53 @@
   return PromiseType;
 }
 
+/// Look up std::experimental::coroutine_handle<PromiseType> and require it to
+/// be a complete class type. Returns a null QualType on failure.
+static QualType lookupCoroutineHandleType(Sema &S, QualType PromiseType,
+                                          SourceLocation Loc) {
+  if (PromiseType.isNull())
+    return QualType();
+
+  NamespaceDecl *StdExp = S.lookupStdExperimentalNamespace();
+  assert(StdExp && "Should already be diagnosed");
+
+  LookupResult Result(S, &S.PP.getIdentifierTable().get("coroutine_handle"),
+                      Loc, Sema::LookupOrdinaryName);
+  if (!S.LookupQualifiedName(Result, StdExp)) {
+    S.Diag(Loc, diag::err_implied_std_coroutine_traits_not_found);
+    return QualType();
+  }
+
+  ClassTemplateDecl *CoroHandle = Result.getAsSingle<ClassTemplateDecl>();
+  if (!CoroHandle) {
+    Result.suppressDiagnostics();
+    // We found something weird. Complain about the first thing we found.
+    NamedDecl *Found = *Result.begin();
+    S.Diag(Found->getLocation(), diag::err_malformed_std_coroutine_handle);
+    return QualType();
+  }
+
+  // Form template argument list for coroutine_handle<PromiseType>.
+  TemplateArgumentListInfo Args(Loc, Loc);
+  Args.addArgument(TemplateArgumentLoc(
+      TemplateArgument(PromiseType),
+      S.Context.getTrivialTypeSourceInfo(PromiseType, Loc)));
+
+  // Build the template-id.
+  QualType CoroHandleType =
+      S.CheckTemplateIdType(TemplateName(CoroHandle), Loc, Args);
+  if (CoroHandleType.isNull())
+    return QualType();
+  if (S.RequireCompleteType(Loc, CoroHandleType,
+                            diag::err_coroutine_traits_missing_specialization))
+    return QualType();
+
+  assert(CoroHandleType->getAsCXXRecordDecl() &&
+         "specialization of class template is not a class?");
+
+  return CoroHandleType;
+}
+
 static bool isValidCoroutineContext(Sema &S, SourceLocation Loc,
                                     StringRef Keyword) {
   // 'co_await' and 'co_yield' are not permitted in unevaluated operands.
@@ -260,20 +307,58 @@
   return S.ActOnCallExpr(nullptr, Result.get(), Loc, Args, Loc, nullptr);
 }
 
+/// Build the expression
+/// `coroutine_handle<PromiseType>::from_address(__builtin_coro_frame())`,
+/// which is passed as the argument to await_suspend.
+static ExprResult buildCoroutineHandle(Sema &S, QualType PromiseType,
+                                       SourceLocation Loc) {
+  QualType HandleType = lookupCoroutineHandleType(S, PromiseType, Loc);
+  if (HandleType.isNull())
+    return ExprError();
+  auto *RD = HandleType->getAsCXXRecordDecl();
+  assert(RD && "must be class type");
+  DeclarationName DN = S.PP.getIdentifierInfo("from_address");
+  LookupResult LR(S, DN, Loc, Sema::LookupMemberName);
+  if (!S.LookupQualifiedName(LR, RD))
+    return ExprError(S.Diag(Loc, diag::err_no_member) << DN << HandleType);
+
+  Expr *FramePtr =
+      buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_frame, {});
+
+  // FIXME: Fix BuildMemberReferenceExpr to take a const CXXScopeSpec&.
+  CXXScopeSpec SS;
+  ExprResult Result = S.BuildMemberReferenceExpr(
+      /*BaseExpr*/ nullptr, HandleType, Loc, /*IsArrow=*/false, SS,
+      SourceLocation(), nullptr, LR, /*TemplateArgs=*/nullptr,
+      /*Scope=*/nullptr);
+  if (Result.isInvalid())
+    return ExprError();
+
+  return S.ActOnCallExpr(nullptr, Result.get(), Loc, FramePtr, Loc, nullptr);
+}
+
 /// Build calls to await_ready, await_suspend, and await_resume for a co_await
-/// expression.
-static ReadySuspendResumeResult buildCoawaitCalls(Sema &S, SourceLocation Loc,
-                                                  Expr *E) {
+/// expression. await_suspend is passed a coroutine_handle<PromiseType>.
+static ReadySuspendResumeResult
+buildCoawaitCalls(Sema &S, SourceLocation Loc, QualType PromiseType, Expr *E) {
   // Assume invalid until we see otherwise.
   ReadySuspendResumeResult Calls = {true, {}};
 
+  ExprResult HandleExprRes = buildCoroutineHandle(S, PromiseType, Loc);
+  if (HandleExprRes.isInvalid())
+    return Calls;
+  Expr *HandleExpr = HandleExprRes.get();
+
   const StringRef Funcs[] = {"await_ready", "await_suspend", "await_resume"};
   for (size_t I = 0, N = llvm::array_lengthof(Funcs); I != N; ++I) {
     Expr *Operand = new (S.Context) OpaqueValueExpr(
       Loc, E->getType(), VK_LValue, E->getObjectKind(), E);
 
-    // FIXME: Pass coroutine handle to await_suspend.
-    ExprResult Result = buildMemberCall(S, Operand, Loc, Funcs[I], None);
+    // Only await_suspend takes an argument: the coroutine handle.
+    MultiExprArg Args = None;
+    if (Funcs[I] == "await_suspend")
+      Args = HandleExpr;
+    ExprResult Result = buildMemberCall(S, Operand, Loc, Funcs[I], Args);
     if (Result.isInvalid())
       return Calls;
     Calls.Results[I] = Result.get();
@@ -475,7 +557,8 @@
     E = CreateMaterializeTemporaryExpr(E->getType(), E, true);
 
   // Build the await_ready, await_suspend, await_resume calls.
-  ReadySuspendResumeResult RSS = buildCoawaitCalls(*this, Loc, E);
+  ReadySuspendResumeResult RSS =
+      buildCoawaitCalls(*this, Loc, Coroutine->CoroutinePromise->getType(), E);
   if (RSS.IsInvalid)
     return ExprError();
 
@@ -528,7 +611,8 @@
     E = CreateMaterializeTemporaryExpr(E->getType(), E, true);
 
   // Build the await_ready, await_suspend, await_resume calls.
-  ReadySuspendResumeResult RSS = buildCoawaitCalls(*this, Loc, E);
+  ReadySuspendResumeResult RSS =
+      buildCoawaitCalls(*this, Loc, Coroutine->CoroutinePromise->getType(), E);
   if (RSS.IsInvalid)
     return ExprError();
 
@@ -869,6 +953,8 @@
 
   // FIXME: Perform move-initialization of parameters into frame-local copies.
   SmallVector<Expr*, 16> ParamMoves;
+  // When a template is instantiated, Body has already been replaced with a
+  // CoroutineBodyStmt; only build one if that has not happened yet.
   if (Body && !isa<CoroutineBodyStmt>(Body)) {
     StmtResult BodyRes = BuildCoroutineBodyStmt(
         Body, FSI->CoroutinePromise, FSI->CoroutineSuspends.first,
Index: include/clang/Basic/DiagnosticSemaKinds.td
===================================================================
--- include/clang/Basic/DiagnosticSemaKinds.td
+++ include/clang/Basic/DiagnosticSemaKinds.td
@@ -8656,6 +8656,8 @@
   "'std::experimental::coroutine_traits' must be a class template">;
 def err_implied_std_coroutine_traits_promise_type_not_found : Error<
   "this function cannot be a coroutine: %q0 has no member named 'promise_type'">;
+def err_malformed_std_coroutine_handle : Error<
+  "'std::experimental::coroutine_handle' must be a class template">;
 def err_implied_std_coroutine_traits_promise_type_not_class : Error<
   "this function cannot be a coroutine: %0 is not a class">;
 def err_coroutine_promise_type_incomplete : Error<
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
http://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to