5chmidti updated this revision to Diff 518159.
5chmidti added a comment.

I took a little break, but here are the changes/fixes:

- moved the logic for variables referenced in captures into the visitor
  - short-circuiting `TraverseLambdaExpr` and using `TraverseLambdaCapture`
    to handle variable captures and the initializers of init-captures
  - fixes both problems mentioned by nridge
- fix for immediately invoked lambda expressions
  - selecting an immediately invoked lambda (IIL) would mark its call operator
    as a referenced decl (fixed in `VisitDeclRefExpr`)
- fix condition in `CanExtractOutside` to allow an unselected parent that is a 
lambda expression
  - allows for the extraction of an initializer for init-capture variables
- prevent default arguments of lambda parameters from being extracted to a
  function-local scope (fixed in `CanExtractOutside`)
- added more tests


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D141757/new/

https://reviews.llvm.org/D141757

Files:
  clang-tools-extra/clangd/refactor/tweaks/ExtractVariable.cpp
  clang-tools-extra/clangd/unittests/tweaks/ExtractVariableTests.cpp
  clang-tools-extra/docs/ReleaseNotes.rst

Index: clang-tools-extra/docs/ReleaseNotes.rst
===================================================================
--- clang-tools-extra/docs/ReleaseNotes.rst
+++ clang-tools-extra/docs/ReleaseNotes.rst
@@ -66,6 +66,11 @@
 Code completion
 ^^^^^^^^^^^^^^^
 
+Code actions
+^^^^^^^^^^^^
+
+- The extract variable tweak gained support for extracting lambda expressions to a variable.
+
 Signature help
 ^^^^^^^^^^^^^^
 
Index: clang-tools-extra/clangd/unittests/tweaks/ExtractVariableTests.cpp
===================================================================
--- clang-tools-extra/clangd/unittests/tweaks/ExtractVariableTests.cpp
+++ clang-tools-extra/clangd/unittests/tweaks/ExtractVariableTests.cpp
@@ -131,7 +131,43 @@
       goto label;
       label:
         a = [[1]];
-    }
+
+      // lambdas
+      [][[(){}]];
+
+      // lambdas: captures
+      int x = 0;
+      [ [[=]] ](){};
+      [ [[&]] ](){};
+      [ [[x]] ](){};
+      [ [[&x] ]](){};
+      [y = [[x]] ](){};
+      [ [[y = x]] ](){};
+
+      // lambdas: default args
+      [](int x = [[10]]){};
+      [](auto h = [i = [[ [](){} ]]](){}) {};
+      [](auto h = [[ [i = [](){}](){} ]]) {};
+
+      // lambdas: scope
+      if (int a = 1)
+            if ([[ [&](){ return a + 1; } ]]() == 4)
+              a = a + 1;
+
+      for (int c = 0; [[ [&]() { return c < b; } ]](); ++c) {
+      }
+      for (int c = 0; [[ [&]() { return c < b; } () ]]; ++c) {
+      }
+
+      // lambdas: scope with structured binding
+      struct Coordinates {
+        int x{};
+        int y{};
+      };
+      Coordinates c{};
+      if (const auto [x, y] = c; x > y)
+        auto f = [[ [&]() { return x + y; } ]];
+      }
   )cpp";
   EXPECT_UNAVAILABLE(UnavailableCases);
 
@@ -282,6 +318,209 @@
                  void f() {
                    auto placeholder = S(2) + S(3) + S(4); S x = S(1) + placeholder + S(5);
                  })cpp"},
+      // lambda expressions
+      {R"cpp(template <typename T> void f(T) {}
+                void f2() {
+                  f([[ [](){ return 42; }]]);
+                }
+                )cpp",
+       R"cpp(template <typename T> void f(T) {}
+                void f2() {
+                  auto placeholder = [](){ return 42; }; f( placeholder);
+                }
+                )cpp"},
+      {R"cpp(auto foo(int VarA) {
+                  return [VarA]() {
+                    return [[ [VarA, VarC = 42 + VarA](int VarB) { return VarA + VarB + VarC; }]];
+                  };
+                }
+                )cpp",
+       R"cpp(auto foo(int VarA) {
+                  return [VarA]() {
+                    auto placeholder = [VarA, VarC = 42 + VarA](int VarB) { return VarA + VarB + VarC; }; return  placeholder;
+                  };
+                }
+                )cpp"},
+      {R"cpp(template <typename T> void f(T) {}
+                void f2(int var) {
+                  f([[ [&var](){ auto internal_val = 42; return var + internal_val; }]]);
+                }
+                )cpp",
+       R"cpp(template <typename T> void f(T) {}
+                void f2(int var) {
+                  auto placeholder = [&var](){ auto internal_val = 42; return var + internal_val; }; f( placeholder);
+                }
+                )cpp"},
+      {R"cpp(template <typename T> void f(T) { }
+                struct A {
+                    void f2(int& var) {
+                        auto local_var = 42;
+                        f([[ [&var, &local_var, this]() {
+                            auto internal_val = 42;
+                            return var + local_var + internal_val + member;
+                        }]]);
+                    }
+
+                    int member = 42;
+};
+                )cpp",
+       R"cpp(template <typename T> void f(T) { }
+                struct A {
+                    void f2(int& var) {
+                        auto local_var = 42;
+                        auto placeholder = [&var, &local_var, this]() {
+                            auto internal_val = 42;
+                            return var + local_var + internal_val + member;
+                        }; f( placeholder);
+                    }
+
+                    int member = 42;
+};
+                )cpp"},
+      {R"cpp(void f() { auto x = [[ [](){ return 42; }]]; })cpp",
+       R"cpp(void f() { auto placeholder = [](){ return 42; }; auto x =  placeholder; })cpp"},
+      {R"cpp(
+        template <typename T>
+        auto sink(T f) { return f(); }
+        int bar() {
+          return sink([[ []() { return 42; }]]);
+        }
+       )cpp",
+       R"cpp(
+        template <typename T>
+        auto sink(T f) { return f(); }
+        int bar() {
+          auto placeholder = []() { return 42; }; return sink( placeholder);
+        }
+       )cpp"},
+      {R"cpp(
+        int main() {
+          if (int a = 1) {
+            if ([[ [&](){ return a + 1; } ]]() == 4)
+              a = a + 1;
+          }
+        })cpp",
+       R"cpp(
+        int main() {
+          if (int a = 1) {
+            auto placeholder = [&](){ return a + 1; }; if ( placeholder () == 4)
+              a = a + 1;
+          }
+        })cpp"},
+      {R"cpp(
+        int main() {
+          if (int a = 1) {
+            if ([[ [&](){ return a + 1; }() ]] == 4)
+              a = a + 1;
+          }
+        })cpp",
+       R"cpp(
+        int main() {
+          if (int a = 1) {
+            auto placeholder = [&](){ return a + 1; }(); if ( placeholder  == 4)
+              a = a + 1;
+          }
+        })cpp"},
+      {R"cpp(
+        template <typename T>
+        auto call(T t) { return t(); }
+
+        int main() {
+          return [[ call([](){ int a = 1; return a + 1; }) ]] + 5;
+        })cpp",
+       R"cpp(
+        template <typename T>
+        auto call(T t) { return t(); }
+
+        int main() {
+          auto placeholder = call([](){ int a = 1; return a + 1; }); return  placeholder  + 5;
+        })cpp"},
+      {R"cpp(
+        class Foo {
+          int bar() {
+            return [f = [[ [this](int g) { return g + x; } ]] ]() { return 42; }();
+          }
+          int x;
+        };
+      )cpp",
+       R"cpp(
+        class Foo {
+          int bar() {
+            auto placeholder = [this](int g) { return g + x; }; return [f =  placeholder  ]() { return 42; }();
+          }
+          int x;
+        };
+      )cpp"},
+      {R"cpp(
+        int main() {
+          return [[ []() { return 42; }() ]];
+        })cpp",
+       R"cpp(
+        int main() {
+          auto placeholder = []() { return 42; }(); return  placeholder ;
+        })cpp"},
+      {R"cpp(
+        template <typename ...Ts>
+        void foo(Ts ...args) {
+          auto x = [[ [&args...]() {} ]];
+        }
+      )cpp",
+       R"cpp(
+        template <typename ...Ts>
+        void foo(Ts ...args) {
+          auto placeholder = [&args...]() {}; auto x =  placeholder ;
+        }
+      )cpp"},
+      {R"cpp(
+        struct Coordinates {
+          int x{};
+          int y{};
+        };
+
+        int main() {
+          Coordinates c = {};
+          const auto [x, y] = c;
+          auto f = [[ [&]() { return x + y; } ]];
+        }
+        )cpp",
+       R"cpp(
+        struct Coordinates {
+          int x{};
+          int y{};
+        };
+
+        int main() {
+          Coordinates c = {};
+          const auto [x, y] = c;
+          auto placeholder = [&]() { return x + y; }; auto f =  placeholder ;
+        }
+        )cpp"},
+      {R"cpp(
+        struct Coordinates {
+          int x{};
+          int y{};
+        };
+
+        int main() {
+          Coordinates c = {};
+          if (const auto [x, y] = c; x > y) {
+            auto f = [[ [&]() { return x + y; } ]];
+          }
+        }
+        )cpp",
+       R"cpp(
+        struct Coordinates {
+          int x{};
+          int y{};
+        };
+
+        int main() {
+          Coordinates c = {};
+          if (const auto [x, y] = c; x > y) {
+            auto placeholder = [&]() { return x + y; }; auto f =  placeholder ;
+          }
+        }
+        )cpp"},
       // Don't try to analyze across macro boundaries
       // FIXME: it'd be nice to do this someday (in a safe way)
       {R"cpp(#define ECHO(X) X
Index: clang-tools-extra/clangd/refactor/tweaks/ExtractVariable.cpp
===================================================================
--- clang-tools-extra/clangd/refactor/tweaks/ExtractVariable.cpp
+++ clang-tools-extra/clangd/refactor/tweaks/ExtractVariable.cpp
@@ -12,8 +12,11 @@
 #include "SourceCode.h"
 #include "refactor/Tweak.h"
 #include "clang/AST/ASTContext.h"
+#include "clang/AST/Decl.h"
+#include "clang/AST/DeclCXX.h"
 #include "clang/AST/Expr.h"
 #include "clang/AST/ExprCXX.h"
+#include "clang/AST/LambdaCapture.h"
 #include "clang/AST/OperationKinds.h"
 #include "clang/AST/RecursiveASTVisitor.h"
 #include "clang/AST/Stmt.h"
@@ -74,10 +77,29 @@
   public:
     std::vector<Decl *> ReferencedDecls;
     bool VisitDeclRefExpr(DeclRefExpr *DeclRef) { // NOLINT
+      // Stop the call operator of lambdas from being marked as a referenced
+      // DeclRefExpr in immediately invoked lambdas.
+      if (const auto *const Method =
+              llvm::dyn_cast<CXXMethodDecl>(DeclRef->getDecl());
+          Method != nullptr && Method->getParent()->isLambda()) {
+        return true;
+      }
       ReferencedDecls.push_back(DeclRef->getDecl());
       return true;
     }
+
+    // Local variables declared inside of the selected lambda cannot go out of
+    // scope. The DeclRefExprs that are important are the variables captured and
+    // the DeclRefExprs inside the initializers of init-capture variables.
+    bool TraverseLambdaExpr(LambdaExpr *LExpr) {
+      for (const auto &[Capture, Initializer] :
+           llvm::zip(LExpr->captures(), LExpr->capture_inits())) {
+        TraverseLambdaCapture(LExpr, &Capture, Initializer);
+      }
+      return true;
+    }
   };
+
   FindDeclRefsVisitor Visitor;
   Visitor.TraverseStmt(const_cast<Stmt *>(cast<Stmt>(Expr)));
   return Visitor.ReferencedDecls;
@@ -152,10 +174,23 @@
   auto CanExtractOutside =
       [](const SelectionTree::Node *InsertionPoint) -> bool {
     if (const clang::Stmt *Stmt = InsertionPoint->ASTNode.get<clang::Stmt>()) {
-      // Allow all expressions except LambdaExpr since we don't want to extract
-      // from the captures/default arguments of a lambda
-      if (isa<clang::Expr>(Stmt))
-        return !isa<LambdaExpr>(Stmt);
+      // Allow all expressions except partial LambdaExpr selections since we
+      // don't want to extract from the captures/default arguments of a lambda
+      if (isa<clang::Expr>(Stmt)) {
+        // Do not allow extraction from the initializer of a defaulted parameter
+        // (of a lambda) to a local variable.
+        if (InsertionPoint->Parent->ASTNode.get<ParmVarDecl>() != nullptr) {
+          return false;
+        }
+
+        // Allow expressions, but only allow completely selected lambda
+        // expressions or unselected lambda expressions that are the parent of
+        // the originally selected node, not partially selected lambda
+        // expressions.
+        return !isa<LambdaExpr>(Stmt) ||
+               InsertionPoint->Selected != SelectionTree::Partial;
+      }
+
       // We don't yet allow extraction from switch/case stmt as we would need to
       // jump over the switch stmt even if there is a CompoundStmt inside the
       // switch. And there are other Stmts which we don't care about (e.g.
@@ -240,7 +275,7 @@
     SelectedOperands.clear();
 
     if (const BinaryOperator *Op =
-        llvm::dyn_cast_or_null<BinaryOperator>(N.ASTNode.get<Expr>())) {
+            llvm::dyn_cast_or_null<BinaryOperator>(N.ASTNode.get<Expr>())) {
       Kind = Op->getOpcode();
       ExprLoc = Op->getExprLoc();
       SelectedOperands = N.Children;
@@ -255,7 +290,7 @@
       Kind = BinaryOperator::getOverloadedOpcode(Op->getOperator());
       ExprLoc = Op->getExprLoc();
       // Not all children are args, there's also the callee (operator).
-      for (const auto* Child : N.Children) {
+      for (const auto *Child : N.Children) {
         const Expr *E = Child->ASTNode.get<Expr>();
         assert(E && "callee and args should be Exprs!");
         if (E == Op->getArg(0) || E == Op->getArg(1))
@@ -376,15 +411,15 @@
   if (llvm::isa<SwitchCase>(Outer))
     return true;
   // Control flow statements use condition etc, but not the body.
-  if (const auto* WS = llvm::dyn_cast<WhileStmt>(Outer))
+  if (const auto *WS = llvm::dyn_cast<WhileStmt>(Outer))
     return Inner == WS->getBody();
-  if (const auto* DS = llvm::dyn_cast<DoStmt>(Outer))
+  if (const auto *DS = llvm::dyn_cast<DoStmt>(Outer))
     return Inner == DS->getBody();
-  if (const auto* FS = llvm::dyn_cast<ForStmt>(Outer))
+  if (const auto *FS = llvm::dyn_cast<ForStmt>(Outer))
     return Inner == FS->getBody();
-  if (const auto* FS = llvm::dyn_cast<CXXForRangeStmt>(Outer))
+  if (const auto *FS = llvm::dyn_cast<CXXForRangeStmt>(Outer))
     return Inner == FS->getBody();
-  if (const auto* IS = llvm::dyn_cast<IfStmt>(Outer))
+  if (const auto *IS = llvm::dyn_cast<IfStmt>(Outer))
     return Inner == IS->getThen() || Inner == IS->getElse();
   // Assume all other cases may be actual expressions.
   // This includes the important case of subexpressions (where Outer is Expr).
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to