5chmidti updated this revision to Diff 477184.
5chmidti added a comment.

Fixup: rm added includes


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D138499/new/

https://reviews.llvm.org/D138499

Files:
  clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp
  clang-tools-extra/clangd/unittests/tweaks/ExtractFunctionTests.cpp
  clang-tools-extra/docs/ReleaseNotes.rst

Index: clang-tools-extra/docs/ReleaseNotes.rst
===================================================================
--- clang-tools-extra/docs/ReleaseNotes.rst
+++ clang-tools-extra/docs/ReleaseNotes.rst
@@ -78,6 +78,9 @@
 Miscellaneous
 ^^^^^^^^^^^^^
 
+- The extract function tweak gained support for hoisting, i.e. returning
+  variables declared inside the selection that are still used after it.
+
 Improvements to clang-doc
 -------------------------
 
Index: clang-tools-extra/clangd/unittests/tweaks/ExtractFunctionTests.cpp
===================================================================
--- clang-tools-extra/clangd/unittests/tweaks/ExtractFunctionTests.cpp
+++ clang-tools-extra/clangd/unittests/tweaks/ExtractFunctionTests.cpp
@@ -30,8 +30,9 @@
   EXPECT_EQ(apply("auto lam = [](){ [[int x;]] }; "), "unavailable");
   // Partial statements aren't extracted.
   EXPECT_THAT(apply("int [[x = 0]];"), "unavailable");
-  // FIXME: Support hoisting.
-  EXPECT_THAT(apply(" [[int a = 5;]] a++; "), "unavailable");
+
+  // Extract regions that require hoisting
+  EXPECT_THAT(apply(" [[int a = 5;]] a++; "), HasSubstr("extracted"));
 
   // Ensure that end of Zone and Beginning of PostZone being adjacent doesn't
   // lead to break being included in the extraction zone.
@@ -192,6 +193,202 @@
   EXPECT_EQ(apply(CompoundFailInput), "unavailable");
 }
 
+TEST_F(ExtractFunctionTest, Hoisting) {
+  std::string HoistingInput = R"cpp(
+    int foo() {
+      int a = 3;
+      [[int x = 39 + a;
+      ++x;
+      int y = x * 2;
+      int z = 4;]]
+      return x + y + z;
+    }
+  )cpp";
+  std::string HoistingOutput = R"cpp(
+    auto extracted(int &a) {
+int x = 39 + a;
+      ++x;
+      int y = x * 2;
+      int z = 4;
+return std::tuple{x, y, z};
+}
+int foo() {
+      int a = 3;
+      auto [x, y, z] = extracted(a);
+      return x + y + z;
+    }
+  )cpp";
+  EXPECT_EQ(apply(HoistingInput), HoistingOutput);
+
+  std::string HoistingInput2 = R"cpp(
+    int foo() {
+      int a{};
+      [[int b = a + 1;]]
+      return b;
+    }
+  )cpp";
+  std::string HoistingOutput2 = R"cpp(
+    int extracted(int &a) {
+int b = a + 1;
+return b;
+}
+int foo() {
+      int a{};
+      auto b = extracted(a);
+      return b;
+    }
+  )cpp";
+  EXPECT_EQ(apply(HoistingInput2), HoistingOutput2);
+
+  std::string HoistingInput3 = R"cpp(
+    int foo(int b) {
+      int a{};
+      if (b == 42) {
+        [[a = 123;
+        return a + b;]]
+      }
+      a = 456;
+      return a;
+    }
+  )cpp";
+  std::string HoistingOutput3 = R"cpp(
+    int extracted(int &b, int &a) {
+a = 123;
+        return a + b;
+}
+int foo(int b) {
+      int a{};
+      if (b == 42) {
+        return extracted(b, a);
+      }
+      a = 456;
+      return a;
+    }
+  )cpp";
+  EXPECT_EQ(apply(HoistingInput3), HoistingOutput3);
+
+  std::string HoistingInput4 = R"cpp(
+    struct A {
+      bool flag;
+      int val;
+    };
+    A bar();
+    int foo(int b) {
+      int a = 0;
+      [[auto [flag, val] = bar();
+      int c = 4;
+      val = c + a;]]
+      return a + b + c + val;
+    }
+  )cpp";
+  std::string HoistingOutput4 = R"cpp(
+    struct A {
+      bool flag;
+      int val;
+    };
+    A bar();
+    auto extracted(int &a) {
+auto [flag, val] = bar();
+      int c = 4;
+      val = c + a;
+return std::pair{val, c};
+}
+int foo(int b) {
+      int a = 0;
+      auto [val, c] = extracted(a);
+      return a + b + c + val;
+    }
+  )cpp";
+  EXPECT_EQ(apply(HoistingInput4), HoistingOutput4);
+}
+
+TEST_F(ExtractFunctionTest, HoistingCXX11) {
+  ExtraArgs.emplace_back("-std=c++11");
+  std::string HoistingInput = R"cpp(
+    int foo() {
+      int a = 3;
+      [[int x = 39 + a;
+      ++x;
+      int y = x * 2;
+      int z = 4;]]
+      return x + y + z;
+    }
+  )cpp";
+  EXPECT_THAT(apply(HoistingInput), HasSubstr("unavailable"));
+
+  std::string HoistingInput2 = R"cpp(
+    int foo() {
+      int a;
+      [[int b = a + 1;]]
+      return b;
+    }
+  )cpp";
+  std::string HoistingOutput2 = R"cpp(
+    int extracted(int &a) {
+int b = a + 1;
+return b;
+}
+int foo() {
+      int a;
+      auto b = extracted(a);
+      return b;
+    }
+  )cpp";
+  EXPECT_EQ(apply(HoistingInput2), HoistingOutput2);
+}
+
+TEST_F(ExtractFunctionTest, HoistingCXX14) {
+  ExtraArgs.emplace_back("-std=c++14");
+  std::string HoistingInput = R"cpp(
+    int foo() {
+      int a = 3;
+      [[int x = 39 + a;
+      ++x;
+      int y = x * 2;
+      int z = 4;]]
+      return x + y + z;
+    }
+  )cpp";
+  std::string HoistingOutput = R"cpp(
+    auto extracted(int &a) {
+int x = 39 + a;
+      ++x;
+      int y = x * 2;
+      int z = 4;
+return std::tuple{x, y, z};
+}
+int foo() {
+      int a = 3;
+      auto returned = extracted(a);
+auto x = std::get<0>(returned);
+auto y = std::get<1>(returned);
+auto z = std::get<2>(returned);
+      return x + y + z;
+    }
+  )cpp";
+  EXPECT_EQ(apply(HoistingInput), HoistingOutput);
+
+  std::string HoistingInput2 = R"cpp(
+    int foo() {
+      int a;
+      [[int b = a + 1;]]
+      return b;
+    }
+  )cpp";
+  std::string HoistingOutput2 = R"cpp(
+    int extracted(int &a) {
+int b = a + 1;
+return b;
+}
+int foo() {
+      int a;
+      auto b = extracted(a);
+      return b;
+    }
+  )cpp";
+  EXPECT_EQ(apply(HoistingInput2), HoistingOutput2);
+}
+
 TEST_F(ExtractFunctionTest, DifferentHeaderSourceTest) {
   Header = R"cpp(
     class SomeClass {
Index: clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp
===================================================================
--- clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp
+++ clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp
@@ -80,6 +80,13 @@
 
 using Node = SelectionTree::Node;
 
+struct HoistSetComparator {
+  bool operator()(const Decl *const Lhs, const Decl *const Rhs) const {
+    return Lhs->getLocation() < Rhs->getLocation();
+  }
+};
+using HoistSet = llvm::SmallSet<const NamedDecl *, 1, HoistSetComparator>;
+
 // ExtractionZone is the part of code that is being extracted.
 // EnclosingFunction is the function/method inside which the zone lies.
 // We split the file into 4 parts relative to extraction zone.
@@ -172,12 +179,13 @@
   // semicolon after the extraction.
   const Node *getLastRootStmt() const { return Parent->Children.back(); }
 
-  // Checks if declarations inside extraction zone are accessed afterwards.
+  // Collects the declarations made inside the extraction zone that are
+  // referenced after it; the returned set is empty when there are none.
   //
   // This performs a partial AST traversal proportional to the size of the
   // enclosing function, so it is possibly expensive.
-  bool requiresHoisting(const SourceManager &SM,
-                        const HeuristicResolver *Resolver) const {
+  HoistSet getDeclsToHoist(const SourceManager &SM,
+                           const HeuristicResolver *Resolver) const {
     // First find all the declarations that happened inside extraction zone.
     llvm::SmallSet<const Decl *, 1> DeclsInExtZone;
     for (auto *RootStmt : RootStmts) {
@@ -192,29 +200,31 @@
     }
     // Early exit without performing expensive traversal below.
     if (DeclsInExtZone.empty())
-      return false;
-    // Then make sure they are not used outside the zone.
+      return {};
+    // Add any decl used after the selection to the returned set
+    HoistSet DeclsToHoist{};
     for (const auto *S : EnclosingFunction->getBody()->children()) {
       if (SM.isBeforeInTranslationUnit(S->getSourceRange().getEnd(),
                                        ZoneRange.getEnd()))
         continue;
-      bool HasPostUse = false;
       findExplicitReferences(
           S,
           [&](const ReferenceLoc &Loc) {
-            if (HasPostUse ||
-                SM.isBeforeInTranslationUnit(Loc.NameLoc, ZoneRange.getEnd()))
+            if (SM.isBeforeInTranslationUnit(Loc.NameLoc, ZoneRange.getEnd()))
               return;
-            HasPostUse = llvm::any_of(Loc.Targets,
-                                      [&DeclsInExtZone](const Decl *Target) {
-                                        return DeclsInExtZone.contains(Target);
-                                      });
+            const auto *const PostUseIter = llvm::find_if(
+                Loc.Targets, [&DeclsInExtZone](const Decl *Target) {
+                  return DeclsInExtZone.contains(Target);
+                });
+
+            if (const bool FoundPostUse = PostUseIter != Loc.Targets.end();
+                FoundPostUse) {
+              DeclsToHoist.insert(*PostUseIter);
+            }
           },
           Resolver);
-      if (HasPostUse)
-        return true;
     }
-    return false;
+    return DeclsToHoist;
   }
 };
 
@@ -368,16 +378,20 @@
   bool Static = false;
   ConstexprSpecKind Constexpr = ConstexprSpecKind::Unspecified;
   bool Const = false;
+  const HoistSet &ToHoist;
 
   // Decides whether the extracted function body and the function call need a
   // semicolon after extraction.
   tooling::ExtractionSemicolonPolicy SemicolonPolicy;
   const LangOptions *LangOpts;
-  NewFunction(tooling::ExtractionSemicolonPolicy SemicolonPolicy,
+  NewFunction(const HoistSet &ToHoist,
+              tooling::ExtractionSemicolonPolicy SemicolonPolicy,
               const LangOptions *LangOpts)
-      : SemicolonPolicy(SemicolonPolicy), LangOpts(LangOpts) {}
+      : ToHoist(ToHoist), SemicolonPolicy(SemicolonPolicy), LangOpts(LangOpts) {
+  }
   // Render the call for this function.
   std::string renderCall() const;
+  std::string renderHoistedCall() const;
   // Render the definition for this function.
   std::string renderDeclaration(FunctionDeclKind K,
                                 const DeclContext &SemanticDC,
@@ -463,7 +477,58 @@
   return llvm::formatv("{0}{1}", QualifierName, Name);
 }
 
+// Renders the variables and bindings in the HoistSet as a comma separated
+// list (a lone name when there is only one); other decls are skipped.
+std::string renderHoistSet(const HoistSet &ToHoist) {
+  std::string Res{};
+  // Set only once a name has actually been emitted, so that skipped decls
+  // never produce a leading or doubled separator.
+  bool NeedsComma = false;
+  for (const NamedDecl *DeclToHoist : ToHoist) {
+    if (!llvm::isa<VarDecl>(DeclToHoist) &&
+        !llvm::isa<BindingDecl>(DeclToHoist)) {
+      continue;
+    }
+    if (NeedsComma) {
+      Res += ", ";
+    }
+    Res += DeclToHoist->getNameAsString();
+    NeedsComma = true;
+  }
+  return Res;
+}
+
+std::string NewFunction::renderHoistedCall() const {
+  auto HoistedVarDecls = std::string{};
+  auto ExplicitUnpacking = std::string{};
+  const auto HasStructuredBinding = LangOpts->CPlusPlus17;
+
+  if (ToHoist.size() > 1) {
+    if (HasStructuredBinding) {
+      HoistedVarDecls = "auto [" + renderHoistSet(ToHoist) + "] = ";
+    } else {
+      HoistedVarDecls = "auto returned = ";
+      auto DeclIter = ToHoist.begin();
+      for (size_t Index = 0U; Index < ToHoist.size(); ++Index, ++DeclIter) {
+        ExplicitUnpacking +=
+            llvm::formatv("\nauto {0} = std::get<{1}>(returned);",
+                          (*DeclIter)->getNameAsString(), Index);
+      }
+    }
+  } else {
+    HoistedVarDecls = "auto " + renderHoistSet(ToHoist) + " = ";
+  }
+
+  return std::string(llvm::formatv(
+      "{0}{1}({2}){3}{4}", HoistedVarDecls, Name, renderParametersForCall(),
+      (SemicolonPolicy.isNeededInOriginalFunction() ? ";" : ""),
+      ExplicitUnpacking));
+}
+
 std::string NewFunction::renderCall() const {
+  if (!ToHoist.empty())
+    return renderHoistedCall();
+
   return std::string(
       llvm::formatv("{0}{1}({2}){3}", CallerReturnsValue ? "return " : "", Name,
                     renderParametersForCall(),
@@ -496,8 +561,20 @@
   // - hoist decls
   // - add return statement
   // - Add semicolon
-  return toSourceCode(SM, BodyRange).str() +
-         (SemicolonPolicy.isNeededInExtractedFunction() ? ";" : "");
+  auto Body = toSourceCode(SM, BodyRange).str() +
+              (SemicolonPolicy.isNeededInExtractedFunction() ? ";" : "");
+  if (!ToHoist.empty()) {
+    if (const bool NeedsTupleOrPair = ToHoist.size() > 1; NeedsTupleOrPair) {
+      const auto NeedsPair = ToHoist.size() == 2;
+
+      Body += "\nreturn " +
+              std::string(NeedsPair ? "std::pair{" : "std::tuple{") +
+              renderHoistSet(ToHoist) + "};";
+    } else {
+      Body += "\nreturn " + renderHoistSet(ToHoist) + ";";
+    }
+  }
+  return Body;
 }
 
 std::string NewFunction::Parameter::render(const DeclContext *Context) const {
@@ -675,10 +752,6 @@
     const auto &DeclInfo = KeyVal.second;
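-    // If a Decl was Declared in zone and referenced in post zone, it
-    // needs to be hoisted (we bail out in that case).
-    // FIXME: Support Decl Hoisting.
-    if (DeclInfo.DeclaredIn == ZoneRelative::Inside &&
-        DeclInfo.IsReferencedInPostZone)
-      return false;
+    // A Decl declared in zone and referenced in post zone is hoisted: the
+    // extracted function returns it, so it no longer forces a bail out.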
     if (!DeclInfo.IsReferencedInZone)
       continue; // no need to pass as parameter, not referenced
     if (DeclInfo.DeclaredIn == ZoneRelative::Inside ||
@@ -724,6 +797,19 @@
   return SemicolonPolicy;
 }
 
+QualType getReturnTypeForHoisted(const FunctionDecl &EnclosingFunc,
+                                 const HoistSet &ToHoist) {
+  // Hoisting just one variable: use that variable's type instead of auto.
+  if (ToHoist.size() == 1) {
+    if (const auto *const VDecl = llvm::dyn_cast<VarDecl>(*ToHoist.begin());
+        VDecl != nullptr) {
+      return VDecl->getType();
+    }
+  }
+
+  return EnclosingFunc.getParentASTContext().getAutoDeductType();
+}
+
 // Generate return type for ExtractedFunc. Return false if unable to do so.
 bool generateReturnProperties(NewFunction &ExtractedFunc,
                               const FunctionDecl &EnclosingFunc,
@@ -745,7 +831,11 @@
     return true;
   }
   // FIXME: Generate new return statement if needed.
-  ExtractedFunc.ReturnType = EnclosingFunc.getParentASTContext().VoidTy;
+  ExtractedFunc.ReturnType =
+      !ExtractedFunc.ToHoist.empty()
+          ? getReturnTypeForHoisted(EnclosingFunc, ExtractedFunc.ToHoist)
+          : EnclosingFunc.getParentASTContext().VoidTy;
+
   return true;
 }
 
@@ -759,6 +849,7 @@
 // FIXME: add support for adding other function return types besides void.
 // FIXME: assign the value returned by non void extracted function.
 llvm::Expected<NewFunction> getExtractedFunction(ExtractionZone &ExtZone,
+                                                 const HoistSet &ToHoist,
                                                  const SourceManager &SM,
                                                  const LangOptions &LangOpts) {
   CapturedZoneInfo CapturedInfo = captureZoneInfo(ExtZone);
@@ -766,7 +857,7 @@
   if (CapturedInfo.BrokenControlFlow)
     return error("Cannot extract break/continue without corresponding "
                  "loop/switch statement.");
-  NewFunction ExtractedFunc(getSemicolonPolicy(ExtZone, SM, LangOpts),
+  NewFunction ExtractedFunc(ToHoist, getSemicolonPolicy(ExtZone, SM, LangOpts),
                             &LangOpts);
 
   ExtractedFunc.SyntacticDC =
@@ -815,6 +906,7 @@
 
 private:
   ExtractionZone ExtZone;
+  HoistSet ToHoist;
 };
 
 REGISTER_TWEAK(ExtractFunction)
@@ -880,8 +972,12 @@
       (hasReturnStmt(*MaybeExtZone) && !alwaysReturns(*MaybeExtZone)))
     return false;
 
-  // FIXME: Get rid of this check once we support hoisting.
-  if (MaybeExtZone->requiresHoisting(SM, Inputs.AST->getHeuristicResolver()))
+  ToHoist =
+      MaybeExtZone->getDeclsToHoist(SM, Inputs.AST->getHeuristicResolver());
+
+  const auto HasAutoReturnTypeDeduction = LangOpts.CPlusPlus14;
+  const auto RequiresPairOrTuple = ToHoist.size() > 1;
+  if (RequiresPairOrTuple && !HasAutoReturnTypeDeduction)
     return false;
 
   ExtZone = std::move(*MaybeExtZone);
@@ -891,7 +987,7 @@
 Expected<Tweak::Effect> ExtractFunction::apply(const Selection &Inputs) {
   const SourceManager &SM = Inputs.AST->getSourceManager();
   const LangOptions &LangOpts = Inputs.AST->getLangOpts();
-  auto ExtractedFunc = getExtractedFunction(ExtZone, SM, LangOpts);
+  auto ExtractedFunc = getExtractedFunction(ExtZone, ToHoist, SM, LangOpts);
   // FIXME: Add more types of errors.
   if (!ExtractedFunc)
     return ExtractedFunc.takeError();
@@ -914,8 +1010,8 @@
 
       tooling::Replacements OtherEdit(
           createForwardDeclaration(*ExtractedFunc, SM));
-      if (auto PathAndEdit = Tweak::Effect::fileEdit(SM, SM.getFileID(*FwdLoc),
-                                                 OtherEdit))
+      if (auto PathAndEdit =
+              Tweak::Effect::fileEdit(SM, SM.getFileID(*FwdLoc), OtherEdit))
         MultiFileEffect->ApplyEdits.try_emplace(PathAndEdit->first,
                                                 PathAndEdit->second);
       else
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to