ChuanqiXu updated this revision to Diff 550181.
ChuanqiXu added a comment.

Address comments:

- Remove the complicated TypeVisitor, use the simple 
`getNonReferenceType()->getAsCXXRecordDecl()` form instead.
- Reword `*leak*` to `*escape*`.
- Use the suggested comments.


CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D157833/new/

https://reviews.llvm.org/D157833

Files:
  clang/docs/ReleaseNotes.rst
  clang/lib/CodeGen/CGCall.cpp
  clang/lib/CodeGen/CGCoroutine.cpp
  clang/lib/CodeGen/CodeGenFunction.h
  clang/test/CodeGenCoroutines/coro-awaiter-noinline-suspend.cpp
  clang/test/CodeGenCoroutines/pr56301.cpp

Index: clang/test/CodeGenCoroutines/pr56301.cpp
===================================================================
--- /dev/null
+++ clang/test/CodeGenCoroutines/pr56301.cpp
@@ -0,0 +1,85 @@
+// An end-to-end test to make sure things get processed correctly.
+// RUN: %clang_cc1 -std=c++20 -triple x86_64-unknown-linux-gnu -emit-llvm -o - %s -O3 | \
+// RUN:     FileCheck %s
+
+#include "Inputs/coroutine.h"
+
+struct SomeAwaitable {
+  // Resume the supplied handle once the awaitable becomes ready,
+  // returning a handle that should be resumed now for the sake of symmetric transfer.
+  // If the awaitable is already ready, return an empty handle without doing anything.
+  //
+  // Defined in another translation unit. Note that this may contain
+  // code that synchronizes with another thread.
+  std::coroutine_handle<> Register(std::coroutine_handle<>);
+};
+
+// Defined in another translation unit.
+void DidntSuspend();
+
+struct Awaiter {
+  SomeAwaitable&& awaitable;
+  bool suspended;
+
+  bool await_ready() { return false; }
+
+  std::coroutine_handle<> await_suspend(const std::coroutine_handle<> h) {
+    // Assume we will suspend unless proven otherwise below. We must do
+    // this *before* calling Register, since we may be destroyed by another
+    // thread asynchronously as soon as we have registered.
+    suspended = true;
+
+    // Attempt to hand off responsibility for resuming/destroying the coroutine.
+    const auto to_resume = awaitable.Register(h);
+
+    if (!to_resume) {
+      // The awaitable is already ready. In this case we know that Register didn't
+      // hand off responsibility for the coroutine. So record the fact that we didn't
+      // actually suspend, and tell the compiler to resume us inline.
+      suspended = false;
+      return h;
+    }
+
+    // Resume whatever Register wants us to resume.
+    return to_resume;
+  }
+
+  void await_resume() {
+    // If we didn't suspend, make note of that fact.
+    if (!suspended) {
+      DidntSuspend();
+    }
+  }
+};
+
+struct MyTask{
+  struct promise_type {
+    MyTask get_return_object() { return {}; }
+    std::suspend_never initial_suspend() { return {}; }
+    std::suspend_always final_suspend() noexcept { return {}; }
+    void unhandled_exception();
+
+    Awaiter await_transform(SomeAwaitable&& awaitable) {
+      return Awaiter{static_cast<SomeAwaitable&&>(awaitable)};
+    }
+  };
+};
+
+MyTask FooBar() {
+  co_await SomeAwaitable();
+}
+
+// CHECK-LABEL: @_Z6FooBarv
+// CHECK: %[[to_resume:.*]] = {{.*}}call ptr @_ZN13SomeAwaitable8RegisterESt16coroutine_handleIvE
+// CHECK-NEXT: %[[to_bool:.*]] = icmp eq ptr %[[to_resume]], null
+// CHECK-NEXT: br i1 %[[to_bool]], label %[[then:.*]], label %[[else:.*]]
+
+// CHECK: [[then]]:
+// We only access the coroutine frame conditionally as the sources did.
+// CHECK:   store i8 0,
+// CHECK-NEXT: br label %[[else]]
+
+// CHECK: [[else]]:
+// No more access to the coroutine frame until suspended.
+// CHECK-NOT: store
+// CHECK: }
Index: clang/test/CodeGenCoroutines/coro-awaiter-noinline-suspend.cpp
===================================================================
--- /dev/null
+++ clang/test/CodeGenCoroutines/coro-awaiter-noinline-suspend.cpp
@@ -0,0 +1,207 @@
+// Tests that we can mark await-suspend as noinline correctly.
+//
+// RUN: %clang_cc1 -std=c++20 -triple x86_64-unknown-linux-gnu -emit-llvm -o - %s \
+// RUN:     -disable-llvm-passes | FileCheck %s
+
+#include "Inputs/coroutine.h"
+
+struct Task {
+  struct promise_type {
+    struct FinalAwaiter {
+      bool await_ready() const noexcept { return false; }
+      template <typename PromiseType>
+      std::coroutine_handle<> await_suspend(std::coroutine_handle<PromiseType> h) noexcept {
+        return h.promise().continuation;
+      }
+      void await_resume() noexcept {}
+    };
+
+    Task get_return_object() noexcept {
+      return std::coroutine_handle<promise_type>::from_promise(*this);
+    }
+
+    std::suspend_always initial_suspend() noexcept { return {}; }
+    FinalAwaiter final_suspend() noexcept { return {}; }
+    void unhandled_exception() noexcept {}
+    void return_void() noexcept {}
+
+    std::coroutine_handle<> continuation;
+  };
+
+  Task(std::coroutine_handle<promise_type> handle);
+  ~Task();
+
+private:
+  std::coroutine_handle<promise_type> handle;
+};
+
+struct StatefulAwaiter {
+    int value;
+    bool await_ready() const noexcept { return false; }
+    template <typename PromiseType>
+    void await_suspend(std::coroutine_handle<PromiseType> h) noexcept {}
+    void await_resume() noexcept {}
+};
+
+typedef std::suspend_always NoStateAwaiter;
+using AnotherStatefulAwaiter = StatefulAwaiter;
+
+template <class T>
+struct TemplatedAwaiter {
+    T value;
+    bool await_ready() const noexcept { return false; }
+    template <typename PromiseType>
+    void await_suspend(std::coroutine_handle<PromiseType> h) noexcept {}
+    void await_resume() noexcept {}
+};
+
+
+class Awaitable {};
+StatefulAwaiter operator co_await(Awaitable) {
+  return StatefulAwaiter{};
+}
+
+StatefulAwaiter GlobalAwaiter;
+class Awaitable2 {};
+StatefulAwaiter& operator co_await(Awaitable2) {
+  return GlobalAwaiter;
+}
+
+Task testing() {
+    co_await std::suspend_always{};
+    co_await StatefulAwaiter{};
+    co_await AnotherStatefulAwaiter{};
+    
+    // Test lvalue case.
+    StatefulAwaiter awaiter;
+    co_await awaiter;
+
+    // The explicit call to await_suspend is not considered suspended.
+    awaiter.await_suspend(std::coroutine_handle<void>::from_address(nullptr));
+
+    co_await TemplatedAwaiter<int>{};
+    TemplatedAwaiter<int> TemplatedAwaiterInstace;
+    co_await TemplatedAwaiterInstace;
+
+    co_await Awaitable{};
+    co_await Awaitable2{};
+}
+
+// CHECK-LABEL: @_Z7testingv
+
+// Check `co_await __promise__.initial_suspend();` Since it returns std::suspend_always,
+// which is an empty class, we shouldn't generate optimization blocker for it.
+// CHECK: call token @llvm.coro.save
+// CHECK: call void @_ZNSt14suspend_always13await_suspendESt16coroutine_handleIvE{{.*}}#[[NORMAL_ATTR:[0-9]+]]
+
+// Check the `co_await std::suspend_always{};` expression. We shouldn't emit the optimization
+// blocker for it since it is an empty class.
+// CHECK: call token @llvm.coro.save
+// CHECK: call void @_ZNSt14suspend_always13await_suspendESt16coroutine_handleIvE{{.*}}#[[NORMAL_ATTR]]
+
+// Check `co_await StatefulAwaiter{};`. We need to emit the optimization blocker since
+// the awaiter is not empty.
+// CHECK: call token @llvm.coro.save
+// CHECK: call void @_ZN15StatefulAwaiter13await_suspendIN4Task12promise_typeEEEvSt16coroutine_handleIT_E{{.*}}#[[NOINLINE_ATTR:[0-9]+]]
+
+// Check `co_await AnotherStatefulAwaiter{};` to make sure that we can handle TypedefTypes.
+// CHECK: call token @llvm.coro.save
+// CHECK: call void @_ZN15StatefulAwaiter13await_suspendIN4Task12promise_typeEEEvSt16coroutine_handleIT_E{{.*}}#[[NOINLINE_ATTR]]
+
+// Check `co_await awaiter;` to make sure we can handle lvalue cases.
+// CHECK: call token @llvm.coro.save
+// CHECK: call void @_ZN15StatefulAwaiter13await_suspendIN4Task12promise_typeEEEvSt16coroutine_handleIT_E{{.*}}#[[NOINLINE_ATTR]]
+
+// Check `awaiter.await_suspend(...)` to make sure the explicit call to await_suspend won't be marked as noinline
+// CHECK: call void @_ZN15StatefulAwaiter13await_suspendIvEEvSt16coroutine_handleIT_E{{.*}}#[[NORMAL_ATTR]]
+
+// Check `co_await TemplatedAwaiter<int>{};` to make sure we can handle specialized template
+// type.
+// CHECK: call token @llvm.coro.save
+// CHECK: call void @_ZN16TemplatedAwaiterIiE13await_suspendIN4Task12promise_typeEEEvSt16coroutine_handleIT_E{{.*}}#[[NOINLINE_ATTR]]
+
+// Check `co_await TemplatedAwaiterInstace;` to make sure we can handle the lvalue from
+// specialized template type.
+// CHECK: call token @llvm.coro.save
+// CHECK: call void @_ZN16TemplatedAwaiterIiE13await_suspendIN4Task12promise_typeEEEvSt16coroutine_handleIT_E{{.*}}#[[NOINLINE_ATTR]]
+
+// Check `co_await Awaitable{};` to make sure we can handle awaiter returned by
+// `operator co_await`;
+// CHECK: call token @llvm.coro.save
+// CHECK: call void @_ZN15StatefulAwaiter13await_suspendIN4Task12promise_typeEEEvSt16coroutine_handleIT_E{{.*}}#[[NOINLINE_ATTR]]
+
+// Check `co_await Awaitable2{};` to make sure we can handle awaiter returned by
+// `operator co_await` which returns a reference;
+// CHECK: call token @llvm.coro.save
+// CHECK: call void @_ZN15StatefulAwaiter13await_suspendIN4Task12promise_typeEEEvSt16coroutine_handleIT_E{{.*}}#[[NOINLINE_ATTR]]
+
+// Check `co_await __promise__.final_suspend();`. We don't emit a blocker here since it is
+// empty.
+// CHECK: call token @llvm.coro.save
+// CHECK: call ptr @_ZN4Task12promise_type12FinalAwaiter13await_suspendIS0_EESt16coroutine_handleIvES3_IT_E{{.*}}#[[NORMAL_ATTR]]
+
+struct AwaitTransformTask {
+  struct promise_type {
+    struct FinalAwaiter {
+      bool await_ready() const noexcept { return false; }
+      template <typename PromiseType>
+      std::coroutine_handle<> await_suspend(std::coroutine_handle<PromiseType> h) noexcept {
+        return h.promise().continuation;
+      }
+      void await_resume() noexcept {}
+    };
+
+    AwaitTransformTask get_return_object() noexcept {
+      return std::coroutine_handle<promise_type>::from_promise(*this);
+    }
+
+    std::suspend_always initial_suspend() noexcept { return {}; }
+    FinalAwaiter final_suspend() noexcept { return {}; }
+    void unhandled_exception() noexcept {}
+    void return_void() noexcept {}
+
+    template <typename Awaitable>
+    auto await_transform(Awaitable &&awaitable) {
+      return awaitable;
+    }
+
+    std::coroutine_handle<> continuation;
+  };
+
+  AwaitTransformTask(std::coroutine_handle<promise_type> handle);
+  ~AwaitTransformTask();
+
+private:
+  std::coroutine_handle<promise_type> handle;
+};
+
+struct awaitableWithGetAwaiter {
+  bool await_ready() const noexcept { return false; }
+  template <typename PromiseType>
+  void await_suspend(std::coroutine_handle<PromiseType> h) noexcept {}
+  void await_resume() noexcept {}
+};
+
+AwaitTransformTask testingWithAwaitTransform() {
+  co_await awaitableWithGetAwaiter{};
+}
+
+// CHECK-LABEL: @_Z25testingWithAwaitTransformv
+
+// Init suspend
+// CHECK: call token @llvm.coro.save
+// CHECK-NOT: call void @llvm.coro.opt.blocker(
+// CHECK: call void @_ZNSt14suspend_always13await_suspendESt16coroutine_handleIvE{{.*}}#[[NORMAL_ATTR]]
+
+// Check `co_await awaitableWithGetAwaiter{};`.
+// CHECK: call token @llvm.coro.save
+// CHECK-NOT: call void @llvm.coro.opt.blocker(
+// CHECK: call void @_ZN23awaitableWithGetAwaiter13await_suspendIN18AwaitTransformTask12promise_typeEEEvSt16coroutine_handleIT_E{{.*}}#[[NORMAL_ATTR]]
+
+// Final suspend
+// CHECK: call token @llvm.coro.save
+// CHECK-NOT: call void @llvm.coro.opt.blocker(
+// CHECK: call ptr @_ZN18AwaitTransformTask12promise_type12FinalAwaiter13await_suspendIS0_EESt16coroutine_handleIvES3_IT_E{{.*}}#[[NORMAL_ATTR]]
+
+// CHECK-NOT: attributes #[[NORMAL_ATTR]] = {{.*}}noinline
+// CHECK: attributes #[[NOINLINE_ATTR]] = {{.*}}noinline
Index: clang/lib/CodeGen/CodeGenFunction.h
===================================================================
--- clang/lib/CodeGen/CodeGenFunction.h
+++ clang/lib/CodeGen/CodeGenFunction.h
@@ -334,6 +334,7 @@
   struct CGCoroInfo {
     std::unique_ptr<CGCoroData> Data;
     bool InSuspendBlock = false;
+    bool MayCoroHandleEscape = false;
     CGCoroInfo();
     ~CGCoroInfo();
   };
@@ -347,6 +348,10 @@
     return isCoroutine() && CurCoro.InSuspendBlock;
   }
 
+  bool mayCoroHandleEscape() const {
+    return isCoroutine() && CurCoro.MayCoroHandleEscape;
+  }
+
   /// CurGD - The GlobalDecl for the current function being compiled.
   GlobalDecl CurGD;
 
Index: clang/lib/CodeGen/CGCoroutine.cpp
===================================================================
--- clang/lib/CodeGen/CGCoroutine.cpp
+++ clang/lib/CodeGen/CGCoroutine.cpp
@@ -139,6 +139,36 @@
   return true;
 }
 
+/// Return true when the coroutine handle may escape from the await-suspend
+/// (`awaiter.await_suspend(std::coroutine_handle)` expression).
+/// Return false only when the coroutine handle is guaranteed not to escape
+/// in the await-suspend.
+///
+/// While it is always safe to return true, returning false can yield better
+/// performance.
+///
+/// See https://github.com/llvm/llvm-project/issues/56301 and
+/// https://reviews.llvm.org/D157070 for the example and the full discussion.
+///
+/// FIXME: It will be much better to perform such analysis in the middle end.
+/// See the comments in `CodeGenFunction::EmitCall` for example.
+static bool MayCoroHandleEscape(CoroutineSuspendExpr const &S) {
+  CXXRecordDecl *Awaiter =
+      S.getCommonExpr()->getType().getNonReferenceType()->getAsCXXRecordDecl();
+
+  // Return true conservatively if the awaiter type is not a record type.
+  if (!Awaiter)
+    return true;
+
+  // In case the awaiter type is empty, the suspend wouldn't let the coroutine
+  // handle escape.
+  //
+  // TODO: We can improve this by looking into the implementation of
+  // await-suspend and see if the coroutine handle is passed to foreign
+  // functions.
+  return !Awaiter->field_empty();
+}
+
 // Emit suspend expression which roughly looks like:
 //
 //   auto && x = CommonExpr();
@@ -199,8 +229,11 @@
   auto *SaveCall = Builder.CreateCall(CoroSave, {NullPtr});
 
   CGF.CurCoro.InSuspendBlock = true;
+  CGF.CurCoro.MayCoroHandleEscape = MayCoroHandleEscape(S);
   auto *SuspendRet = CGF.EmitScalarExpr(S.getSuspendExpr());
   CGF.CurCoro.InSuspendBlock = false;
+  CGF.CurCoro.MayCoroHandleEscape = false;
+
   if (SuspendRet != nullptr && SuspendRet->getType()->isIntegerTy(1)) {
     // Veto suspension if requested by bool returning await_suspend.
     BasicBlock *RealSuspendBlock =
Index: clang/lib/CodeGen/CGCall.cpp
===================================================================
--- clang/lib/CodeGen/CGCall.cpp
+++ clang/lib/CodeGen/CGCall.cpp
@@ -5484,6 +5484,30 @@
         Attrs.addFnAttribute(getLLVMContext(), llvm::Attribute::AlwaysInline);
   }
 
+  // The await_suspend call performed by co_await is essentially asynchronous
+  // to the execution of the coroutine. Inlining it normally into an unsplit
+  // coroutine can cause miscompilation because the coroutine CFG misrepresents
+  // the true control flow of the program: things that happen in the
+  // await_suspend are not guaranteed to happen prior to the resumption of the
+  // coroutine, and things that happen after the resumption of the coroutine
+  // (including its exit and the potential deallocation of the coroutine frame)
+  // are not guaranteed to happen only after the end of await_suspend.
+  //
+  // The short-term solution to this problem is to mark the call as uninlinable.
+  // But we don't want to do this if the call is known to be trivial, which is
+  // very common.
+  //
+  // The long-term solution may introduce patterns like:
+  //
+  //  call @llvm.coro.await_suspend(ptr %awaiter, ptr %handle,
+  //                                ptr @awaitSuspendFn)
+  //
+  // Then it is much easier to perform the safety analysis in the middle end.
+  // If it is safe to inline the call to awaitSuspend, we can replace it in the
+  // CoroEarly pass. Otherwise we could replace it in the CoroSplit pass.
+  if (inSuspendBlock() && mayCoroHandleEscape())
+    Attrs = Attrs.addFnAttribute(getLLVMContext(), llvm::Attribute::NoInline);
+
   // Disable inlining inside SEH __try blocks.
   if (isSEHTryScope()) {
     Attrs = Attrs.addFnAttribute(getLLVMContext(), llvm::Attribute::NoInline);
Index: clang/docs/ReleaseNotes.rst
===================================================================
--- clang/docs/ReleaseNotes.rst
+++ clang/docs/ReleaseNotes.rst
@@ -138,6 +138,10 @@
   class, which can result in miscompiles in some cases.
 - Fix crash on use of a variadic overloaded operator.
   (`#42535 <https://github.com/llvm/llvm-project/issues/42535>_`)
+- Fixed an issue where conditional access to local variables of the awaiter
+  after the coroutine handle escapes in ``await_suspend`` could be incorrectly
+  converted to unconditional access.
+  (`#56301 <https://github.com/llvm/llvm-project/issues/56301>`_)
 
 Bug Fixes to Compiler Builtins
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to