tbaeder updated this revision to Diff 522587.

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D150111/new/

https://reviews.llvm.org/D150111

Files:
  clang/lib/AST/Interp/ByteCodeEmitter.cpp
  clang/lib/AST/Interp/ByteCodeStmtGen.cpp
  clang/lib/AST/Interp/ByteCodeStmtGen.h
  clang/lib/AST/Interp/Function.h
  clang/lib/AST/Interp/Interp.h
  clang/test/AST/Interp/lambda.cpp

Index: clang/test/AST/Interp/lambda.cpp
===================================================================
--- clang/test/AST/Interp/lambda.cpp
+++ clang/test/AST/Interp/lambda.cpp
@@ -107,3 +107,58 @@
   static_assert(foo() == 1); // expected-error {{not an integral constant expression}}
 }
 
+namespace StaticInvoker {
+  constexpr int sv1(int i) {
+    auto l = []() { return 12; };
+    int (*fp)() = l;
+    return fp();
+  }
+  static_assert(sv1(12) == 12);
+
+  constexpr int sv2(int i) {
+    auto l = [](int m, float f, void *A) { return m; };
+    int (*fp)(int, float, void*) = l;
+    return fp(i, 4.0f, nullptr);
+  }
+  static_assert(sv2(12) == 12);
+
+  constexpr int sv3(int i) {
+    auto l = [](int m, const int &n) { return m; };
+    int (*fp)(int, const int &) = l;
+    return fp(i, 3);
+  }
+  static_assert(sv3(12) == 12);
+
+  constexpr int sv4(int i) {
+    auto l = [](int &m) { return m; };
+    int (*fp)(int&) = l;
+    return fp(i);
+  }
+  static_assert(sv4(12) == 12);
+
+
+
+  /// FIXME: This is broken for lambda-unrelated reasons.
+#if 0
+  constexpr int sv5(int i) {
+    struct F { int a; float f; };
+    auto l = [](int m, F f) { return m; };
+    int (*fp)(int, F) = l;
+    return fp(i, F{12, 14.0});
+  }
+  static_assert(sv5(12) == 12);
+#endif
+
+  constexpr int sv6(int i) {
+    struct F { int a;
+      constexpr F(int a) : a(a) {}
+    };
+
+    auto l = [](int m) { return F(12); };
+    F (*fp)(int) = l;
+    F f = fp(i);
+
+    return fp(i).a;
+  }
+  static_assert(sv6(12) == 12);
+}
Index: clang/lib/AST/Interp/Interp.h
===================================================================
--- clang/lib/AST/Interp/Interp.h
+++ clang/lib/AST/Interp/Interp.h
@@ -1706,8 +1706,16 @@
 
     const Pointer &ThisPtr = S.Stk.peek<Pointer>(ThisOffset);
 
-    if (!CheckInvoke(S, OpPC, ThisPtr))
-      return false;
+    // If the current function is a lambda static invoker and
+    // the function we're about to call is a lambda call operator,
+    // skip the CheckInvoke, since the ThisPtr is a null pointer
+    // anyway.
+    if (!(S.Current->getFunction() &&
+          S.Current->getFunction()->isLambdaStaticInvoker() &&
+          Func->isLambdaCallOperator())) {
+      if (!CheckInvoke(S, OpPC, ThisPtr))
+        return false;
+    }
 
     if (S.checkingPotentialConstantExpression())
       return false;
Index: clang/lib/AST/Interp/Function.h
===================================================================
--- clang/lib/AST/Interp/Function.h
+++ clang/lib/AST/Interp/Function.h
@@ -17,6 +17,7 @@
 
 #include "Pointer.h"
 #include "Source.h"
+#include "clang/AST/ASTLambda.h"
 #include "clang/AST/Decl.h"
 #include "llvm/Support/raw_ostream.h"
 
@@ -65,7 +66,7 @@
 /// the argument values need to be preceeded by a Pointer for the This object.
 ///
 /// If the function uses Return Value Optimization, the arguments (and
-/// potentially the This pointer) need to be proceeded by a Pointer pointing
+/// potentially the This pointer) need to be preceded by a Pointer pointing
 /// to the location to construct the returned value.
 ///
 /// After the function has been called, it will remove all arguments,
@@ -127,7 +128,7 @@
   SourceInfo getSource(CodePtr PC) const;
 
   /// Checks if the function is valid to call in constexpr.
-  bool isConstexpr() const { return IsValid; }
+  bool isConstexpr() const { return IsValid || isLambdaStaticInvoker(); }
 
   /// Checks if the function is virtual.
   bool isVirtual() const;
@@ -144,6 +145,22 @@
     return nullptr;
   }
 
+  /// Returns whether this function is a lambda static invoker,
+  /// which we generate custom byte code for.
+  bool isLambdaStaticInvoker() const {
+    if (const auto *MD = dyn_cast<CXXMethodDecl>(F))
+      return MD->isLambdaStaticInvoker();
+    return false;
+  }
+
+  /// Returns whether this function is the call operator
+  /// of a lambda record decl.
+  bool isLambdaCallOperator() const {
+    if (const auto *MD = dyn_cast<CXXMethodDecl>(F))
+      return clang::isLambdaCallOperator(MD);
+    return false;
+  }
+
   /// Checks if the function is fully done compiling.
   bool isFullyCompiled() const { return IsFullyCompiled; }
 
Index: clang/lib/AST/Interp/ByteCodeStmtGen.h
===================================================================
--- clang/lib/AST/Interp/ByteCodeStmtGen.h
+++ clang/lib/AST/Interp/ByteCodeStmtGen.h
@@ -68,6 +68,8 @@
   bool visitCaseStmt(const CaseStmt *S);
   bool visitDefaultStmt(const DefaultStmt *S);
 
+  bool emitLambdaStaticInvokerBody(const CXXMethodDecl *MD);
+
   /// Type of the expression returned by the function.
   std::optional<PrimType> ReturnType;
 
Index: clang/lib/AST/Interp/ByteCodeStmtGen.cpp
===================================================================
--- clang/lib/AST/Interp/ByteCodeStmtGen.cpp
+++ clang/lib/AST/Interp/ByteCodeStmtGen.cpp
@@ -89,11 +89,67 @@
 } // namespace interp
 } // namespace clang
 
+template <class Emitter>
+bool ByteCodeStmtGen<Emitter>::emitLambdaStaticInvokerBody(
+    const CXXMethodDecl *MD) {
+  assert(MD->isLambdaStaticInvoker());
+  assert(MD->hasBody());
+  assert(cast<CompoundStmt>(MD->getBody())->body_empty());
+
+  const CXXRecordDecl *ClosureClass = MD->getParent();
+  const CXXMethodDecl *LambdaCallOp = ClosureClass->getLambdaCallOperator();
+  assert(ClosureClass->captures_begin() == ClosureClass->captures_end());
+  const Function *Func = this->getFunction(LambdaCallOp);
+  if (!Func)
+    return false;
+  assert(Func->hasThisPointer());
+  assert(Func->getNumParams() == (MD->getNumParams() + 1 + Func->hasRVO()));
+
+  if (Func->hasRVO()) {
+    if (!this->emitRVOPtr(MD))
+      return false;
+  }
+
+  // The lambda call operator needs an instance pointer, but we don't have
+  // one here, and we don't need one either because the lambda cannot have
+  // any captures, as verified above. Emit a null pointer. This is then
+  // special-cased when interpreting to not emit any misleading diagnostics.
+  if (!this->emitNullPtr(MD))
+    return false;
+
+  // Forward all arguments from the static invoker to the lambda call operator.
+  for (const ParmVarDecl *PVD : MD->parameters()) {
+    auto It = this->Params.find(PVD);
+    assert(It != this->Params.end());
+
+    // We do the lvalue-to-rvalue conversion manually here, so no need
+    // to care about references.
+    PrimType ParamType = this->classify(PVD->getType()).value_or(PT_Ptr);
+    if (!this->emitGetParam(ParamType, It->second, MD))
+      return false;
+  }
+
+  if (!this->emitCall(Func, LambdaCallOp))
+    return false;
+
+  this->emitCleanup();
+  if (ReturnType)
+    return this->emitRet(*ReturnType, MD);
+
+  // Nothing to do, since we emitted the RVO pointer above.
+  return this->emitRetVoid(MD);
+}
+
 template <class Emitter>
 bool ByteCodeStmtGen<Emitter>::visitFunc(const FunctionDecl *F) {
   // Classify the return type.
   ReturnType = this->classify(F->getReturnType());
 
+  // Emit custom code if this is a lambda static invoker.
+  if (const auto *MD = dyn_cast<CXXMethodDecl>(F);
+      MD && MD->isLambdaStaticInvoker())
+    return this->emitLambdaStaticInvokerBody(MD);
+
   // Constructor. Set up field initializers.
   if (const auto Ctor = dyn_cast<CXXConstructorDecl>(F)) {
     const RecordDecl *RD = Ctor->getParent();
Index: clang/lib/AST/Interp/ByteCodeEmitter.cpp
===================================================================
--- clang/lib/AST/Interp/ByteCodeEmitter.cpp
+++ clang/lib/AST/Interp/ByteCodeEmitter.cpp
@@ -96,8 +96,15 @@
   if (!FuncDecl->isDefined())
     return Func;
 
+  // Lambda static invokers are a special case that we emit custom code for.
+  bool IsEligibleForCompilation = false;
+  if (const auto *MD = dyn_cast<CXXMethodDecl>(FuncDecl))
+    IsEligibleForCompilation = MD->isLambdaStaticInvoker();
+  if (!IsEligibleForCompilation)
+    IsEligibleForCompilation = FuncDecl->isConstexpr();
+
   // Compile the function body.
-  if (!FuncDecl->isConstexpr() || !visitFunc(FuncDecl)) {
+  if (!IsEligibleForCompilation || !visitFunc(FuncDecl)) {
     // Return a dummy function if compilation failed.
     if (BailLocation)
       return llvm::make_error<ByteCodeGenError>(*BailLocation);
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to