5chmidti updated this revision to Diff 477184. 5chmidti added a comment. Fixup: rm added includes
Repository: rG LLVM Github Monorepo CHANGES SINCE LAST ACTION https://reviews.llvm.org/D138499/new/ https://reviews.llvm.org/D138499 Files: clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp clang-tools-extra/clangd/unittests/tweaks/ExtractFunctionTests.cpp clang-tools-extra/docs/ReleaseNotes.rst
Index: clang-tools-extra/docs/ReleaseNotes.rst =================================================================== --- clang-tools-extra/docs/ReleaseNotes.rst +++ clang-tools-extra/docs/ReleaseNotes.rst @@ -78,6 +78,9 @@ Miscellaneous ^^^^^^^^^^^^^ +- The extract function tweak now supports hoisting: declarations made inside the + selection that are used after it are returned from the extracted function. + Improvements to clang-doc ------------------------- Index: clang-tools-extra/clangd/unittests/tweaks/ExtractFunctionTests.cpp =================================================================== --- clang-tools-extra/clangd/unittests/tweaks/ExtractFunctionTests.cpp +++ clang-tools-extra/clangd/unittests/tweaks/ExtractFunctionTests.cpp @@ -30,8 +30,9 @@ EXPECT_EQ(apply("auto lam = [](){ [[int x;]] }; "), "unavailable"); // Partial statements aren't extracted. EXPECT_THAT(apply("int [[x = 0]];"), "unavailable"); - // FIXME: Support hoisting. - EXPECT_THAT(apply(" [[int a = 5;]] a++; "), "unavailable"); + + // Regions whose declarations are used after the selection can be extracted. + EXPECT_THAT(apply(" [[int a = 5;]] a++; "), HasSubstr("extracted")); // Ensure that end of Zone and Beginning of PostZone being adjacent doesn't // lead to break being included in the extraction zone. 
@@ -192,6 +193,202 @@ EXPECT_EQ(apply(CompoundFailInput), "unavailable"); } +TEST_F(ExtractFunctionTest, Hoisting) { + std::string HoistingInput = R"cpp( + int foo() { + int a = 3; + [[int x = 39 + a; + ++x; + int y = x * 2; + int z = 4;]] + return x + y + z; + } + )cpp"; + std::string HoistingOutput = R"cpp( + auto extracted(int &a) { +int x = 39 + a; + ++x; + int y = x * 2; + int z = 4; +return std::tuple{x, y, z}; +} +int foo() { + int a = 3; + auto [x, y, z] = extracted(a); + return x + y + z; + } + )cpp"; + EXPECT_EQ(apply(HoistingInput), HoistingOutput); + + std::string HoistingInput2 = R"cpp( + int foo() { + int a{}; + [[int b = a + 1;]] + return b; + } + )cpp"; + std::string HoistingOutput2 = R"cpp( + int extracted(int &a) { +int b = a + 1; +return b; +} +int foo() { + int a{}; + auto b = extracted(a); + return b; + } + )cpp"; + EXPECT_EQ(apply(HoistingInput2), HoistingOutput2); + + std::string HoistingInput3 = R"cpp( + int foo(int b) { + int a{}; + if (b == 42) { + [[a = 123; + return a + b;]] + } + a = 456; + return a; + } + )cpp"; + std::string HoistingOutput3 = R"cpp( + int extracted(int &b, int &a) { +a = 123; + return a + b; +} +int foo(int b) { + int a{}; + if (b == 42) { + return extracted(b, a); + } + a = 456; + return a; + } + )cpp"; + EXPECT_EQ(apply(HoistingInput3), HoistingOutput3); + + std::string HoistingInput4 = R"cpp( + struct A { + bool flag; + int val; + }; + A bar(); + int foo(int b) { + int a = 0; + [[auto [flag, val] = bar(); + int c = 4; + val = c + a;]] + return a + b + c + val; + } + )cpp"; + std::string HoistingOutput4 = R"cpp( + struct A { + bool flag; + int val; + }; + A bar(); + auto extracted(int &a) { +auto [flag, val] = bar(); + int c = 4; + val = c + a; +return std::pair{val, c}; +} +int foo(int b) { + int a = 0; + auto [val, c] = extracted(a); + return a + b + c + val; + } + )cpp"; + EXPECT_EQ(apply(HoistingInput4), HoistingOutput4); +} + +TEST_F(ExtractFunctionTest, HoistingCXX11) { + ExtraArgs.emplace_back("-std=c++11"); 
+ std::string HoistingInput = R"cpp( + int foo() { + int a = 3; + [[int x = 39 + a; + ++x; + int y = x * 2; + int z = 4;]] + return x + y + z; + } + )cpp"; + EXPECT_THAT(apply(HoistingInput), HasSubstr("unavailable")); + + std::string HoistingInput2 = R"cpp( + int foo() { + int a; + [[int b = a + 1;]] + return b; + } + )cpp"; + std::string HoistingOutput2 = R"cpp( + int extracted(int &a) { +int b = a + 1; +return b; +} +int foo() { + int a; + auto b = extracted(a); + return b; + } + )cpp"; + EXPECT_EQ(apply(HoistingInput2), HoistingOutput2); +} + +TEST_F(ExtractFunctionTest, HoistingCXX14) { + ExtraArgs.emplace_back("-std=c++14"); + std::string HoistingInput = R"cpp( + int foo() { + int a = 3; + [[int x = 39 + a; + ++x; + int y = x * 2; + int z = 4;]] + return x + y + z; + } + )cpp"; + std::string HoistingOutput = R"cpp( + auto extracted(int &a) { +int x = 39 + a; + ++x; + int y = x * 2; + int z = 4; +return std::tuple{x, y, z}; +} +int foo() { + int a = 3; + auto returned = extracted(a); +auto x = std::get<0>(returned); +auto y = std::get<1>(returned); +auto z = std::get<2>(returned); + return x + y + z; + } + )cpp"; + EXPECT_EQ(apply(HoistingInput), HoistingOutput); + + std::string HoistingInput2 = R"cpp( + int foo() { + int a; + [[int b = a + 1;]] + return b; + } + )cpp"; + std::string HoistingOutput2 = R"cpp( + int extracted(int &a) { +int b = a + 1; +return b; +} +int foo() { + int a; + auto b = extracted(a); + return b; + } + )cpp"; + EXPECT_EQ(apply(HoistingInput2), HoistingOutput2); +} + TEST_F(ExtractFunctionTest, DifferentHeaderSourceTest) { Header = R"cpp( class SomeClass { Index: clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp =================================================================== --- clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp +++ clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp @@ -80,6 +80,13 @@ using Node = SelectionTree::Node; +struct HoistSetComparator { + bool operator()(const Decl *const 
Lhs, const Decl *const Rhs) const { + return Lhs->getLocation() < Rhs->getLocation(); + } +}; +using HoistSet = llvm::SmallSet<const NamedDecl *, 1, HoistSetComparator>; + // ExtractionZone is the part of code that is being extracted. // EnclosingFunction is the function/method inside which the zone lies. // We split the file into 4 parts relative to extraction zone. @@ -172,12 +179,13 @@ // semicolon after the extraction. const Node *getLastRootStmt() const { return Parent->Children.back(); } - // Checks if declarations inside extraction zone are accessed afterwards. + // Checks if declarations inside extraction zone are accessed afterwards and + // adds these declarations to the returned set. // // This performs a partial AST traversal proportional to the size of the // enclosing function, so it is possibly expensive. - bool requiresHoisting(const SourceManager &SM, - const HeuristicResolver *Resolver) const { + HoistSet getDeclsToHoist(const SourceManager &SM, + const HeuristicResolver *Resolver) const { // First find all the declarations that happened inside extraction zone. llvm::SmallSet<const Decl *, 1> DeclsInExtZone; for (auto *RootStmt : RootStmts) { @@ -192,29 +200,31 @@ } // Early exit without performing expensive traversal below. if (DeclsInExtZone.empty()) - return false; - // Then make sure they are not used outside the zone. 
+ return {}; + // Add any decl used after the selection to the returned set + HoistSet DeclsToHoist{}; for (const auto *S : EnclosingFunction->getBody()->children()) { if (SM.isBeforeInTranslationUnit(S->getSourceRange().getEnd(), ZoneRange.getEnd())) continue; - bool HasPostUse = false; findExplicitReferences( S, [&](const ReferenceLoc &Loc) { - if (HasPostUse || - SM.isBeforeInTranslationUnit(Loc.NameLoc, ZoneRange.getEnd())) + if (SM.isBeforeInTranslationUnit(Loc.NameLoc, ZoneRange.getEnd())) return; - HasPostUse = llvm::any_of(Loc.Targets, - [&DeclsInExtZone](const Decl *Target) { - return DeclsInExtZone.contains(Target); - }); + const auto *const PostUseIter = llvm::find_if( + Loc.Targets, [&DeclsInExtZone](const Decl *Target) { + return DeclsInExtZone.contains(Target); + }); + + if (const bool FoundPostUse = PostUseIter != Loc.Targets.end(); + FoundPostUse) { + DeclsToHoist.insert(*PostUseIter); + } }, Resolver); - if (HasPostUse) - return true; } - return false; + return DeclsToHoist; } }; @@ -368,16 +378,20 @@ bool Static = false; ConstexprSpecKind Constexpr = ConstexprSpecKind::Unspecified; bool Const = false; + const HoistSet &ToHoist; // Decides whether the extracted function body and the function call need a // semicolon after extraction. tooling::ExtractionSemicolonPolicy SemicolonPolicy; const LangOptions *LangOpts; - NewFunction(tooling::ExtractionSemicolonPolicy SemicolonPolicy, + NewFunction(const HoistSet &ToHoist, + tooling::ExtractionSemicolonPolicy SemicolonPolicy, const LangOptions *LangOpts) - : SemicolonPolicy(SemicolonPolicy), LangOpts(LangOpts) {} + : ToHoist(ToHoist), SemicolonPolicy(SemicolonPolicy), LangOpts(LangOpts) { + } // Render the call for this function. std::string renderCall() const; + std::string renderHoistedCall() const; // Render the definition for this function. 
std::string renderDeclaration(FunctionDeclKind K, const DeclContext &SemanticDC, @@ -463,7 +477,58 @@ return llvm::formatv("{0}{1}", QualifierName, Name); } +// Renders the names of the hoisted variables as a comma-separated list. +std::string renderHoistSet(const HoistSet &ToHoist) { + std::string Res{}; + bool NeedsComma = false; + const auto Render = [&NeedsComma, &Res](const NamedDecl *const NDecl) { + if (NeedsComma) { + Res += ", "; + } + Res += NDecl->getNameAsString(); + }; + for (const NamedDecl *DeclToHoist : ToHoist) { + if (llvm::isa<VarDecl>(DeclToHoist) || + llvm::isa<BindingDecl>(DeclToHoist)) { + Render(DeclToHoist); + } + + NeedsComma = true; + } + return Res; +} + +std::string NewFunction::renderHoistedCall() const { + auto HoistedVarDecls = std::string{}; + auto ExplicitUnpacking = std::string{}; + const auto HasStructuredBinding = LangOpts->CPlusPlus17; + + if (ToHoist.size() > 1) { + if (HasStructuredBinding) { + HoistedVarDecls = "auto [" + renderHoistSet(ToHoist) + "] = "; + } else { + HoistedVarDecls = "auto returned = "; + auto DeclIter = ToHoist.begin(); + for (size_t Index = 0U; Index < ToHoist.size(); ++Index, ++DeclIter) { + ExplicitUnpacking += + llvm::formatv("\nauto {0} = std::get<{1}>(returned);", + (*DeclIter)->getNameAsString(), Index); + } + } + } else { + HoistedVarDecls = "auto " + renderHoistSet(ToHoist) + " = "; + } + + return std::string(llvm::formatv( + "{0}{1}({2}){3}{4}", HoistedVarDecls, Name, renderParametersForCall(), + (SemicolonPolicy.isNeededInOriginalFunction() ? ";" : ""), + ExplicitUnpacking)); +} + std::string NewFunction::renderCall() const { + if (!ToHoist.empty()) + return renderHoistedCall(); + return std::string( llvm::formatv("{0}{1}({2}){3}", CallerReturnsValue ? "return " : "", Name, renderParametersForCall(), @@ -496,8 +561,20 @@ // - hoist decls // - add return statement // - Add semicolon - return toSourceCode(SM, BodyRange).str() + - (SemicolonPolicy.isNeededInExtractedFunction() ?
";" : ""); + auto Body = toSourceCode(SM, BodyRange).str() + + (SemicolonPolicy.isNeededInExtractedFunction() ? ";" : ""); + if (!ToHoist.empty()) { + if (const bool NeedsTupleOrPair = ToHoist.size() > 1; NeedsTupleOrPair) { + const auto NeedsPair = ToHoist.size() == 2; + + Body += "\nreturn " + + std::string(NeedsPair ? "std::pair{" : "std::tuple{") + + renderHoistSet(ToHoist) + "};"; + } else { + Body += "\nreturn " + renderHoistSet(ToHoist) + ";"; + } + } + return Body; } std::string NewFunction::Parameter::render(const DeclContext *Context) const { @@ -675,10 +752,6 @@ const auto &DeclInfo = KeyVal.second; // If a Decl was Declared in zone and referenced in post zone, it - // needs to be hoisted (we bail out in that case). + // needs to be hoisted; that is handled separately through ToHoist. - // FIXME: Support Decl Hoisting. - if (DeclInfo.DeclaredIn == ZoneRelative::Inside && - DeclInfo.IsReferencedInPostZone) - return false; if (!DeclInfo.IsReferencedInZone) continue; // no need to pass as parameter, not referenced if (DeclInfo.DeclaredIn == ZoneRelative::Inside || @@ -724,6 +797,19 @@ return SemicolonPolicy; } +QualType getReturnTypeForHoisted(const FunctionDecl &EnclosingFunc, + const HoistSet &ToHoist) { + // Hoisting just one variable: use that variable's type instead of auto + if (ToHoist.size() == 1) { + if (const auto *const VDecl = llvm::dyn_cast<VarDecl>(*ToHoist.begin()); + VDecl != nullptr) { + return VDecl->getType(); + } + } + + return EnclosingFunc.getParentASTContext().getAutoDeductType(); +} + // Generate return type for ExtractedFunc. Return false if unable to do so. bool generateReturnProperties(NewFunction &ExtractedFunc, const FunctionDecl &EnclosingFunc, @@ -745,7 +831,11 @@ return true; } // FIXME: Generate new return statement if needed. - ExtractedFunc.ReturnType = EnclosingFunc.getParentASTContext().VoidTy; + ExtractedFunc.ReturnType = + !ExtractedFunc.ToHoist.empty() + ?
getReturnTypeForHoisted(EnclosingFunc, ExtractedFunc.ToHoist) + : EnclosingFunc.getParentASTContext().VoidTy; + return true; } @@ -759,6 +849,7 @@ // FIXME: add support for adding other function return types besides void. // FIXME: assign the value returned by non void extracted function. llvm::Expected<NewFunction> getExtractedFunction(ExtractionZone &ExtZone, + const HoistSet &ToHoist, const SourceManager &SM, const LangOptions &LangOpts) { CapturedZoneInfo CapturedInfo = captureZoneInfo(ExtZone); @@ -766,7 +857,7 @@ if (CapturedInfo.BrokenControlFlow) return error("Cannot extract break/continue without corresponding " "loop/switch statement."); - NewFunction ExtractedFunc(getSemicolonPolicy(ExtZone, SM, LangOpts), + NewFunction ExtractedFunc(ToHoist, getSemicolonPolicy(ExtZone, SM, LangOpts), &LangOpts); ExtractedFunc.SyntacticDC = @@ -815,6 +906,7 @@ private: ExtractionZone ExtZone; + HoistSet ToHoist; }; REGISTER_TWEAK(ExtractFunction) @@ -880,8 +972,12 @@ (hasReturnStmt(*MaybeExtZone) && !alwaysReturns(*MaybeExtZone))) return false; - // FIXME: Get rid of this check once we support hoisting. - if (MaybeExtZone->requiresHoisting(SM, Inputs.AST->getHeuristicResolver())) + ToHoist = + MaybeExtZone->getDeclsToHoist(SM, Inputs.AST->getHeuristicResolver()); + + const auto HasAutoReturnTypeDeduction = LangOpts.CPlusPlus14; + const auto RequiresPairOrTuple = ToHoist.size() > 1; + if (RequiresPairOrTuple && !HasAutoReturnTypeDeduction) return false; ExtZone = std::move(*MaybeExtZone); @@ -891,7 +987,7 @@ Expected<Tweak::Effect> ExtractFunction::apply(const Selection &Inputs) { const SourceManager &SM = Inputs.AST->getSourceManager(); const LangOptions &LangOpts = Inputs.AST->getLangOpts(); - auto ExtractedFunc = getExtractedFunction(ExtZone, SM, LangOpts); + auto ExtractedFunc = getExtractedFunction(ExtZone, ToHoist, SM, LangOpts); // FIXME: Add more types of errors. 
if (!ExtractedFunc) return ExtractedFunc.takeError(); @@ -914,8 +1010,8 @@ tooling::Replacements OtherEdit( createForwardDeclaration(*ExtractedFunc, SM)); - if (auto PathAndEdit = Tweak::Effect::fileEdit(SM, SM.getFileID(*FwdLoc), - OtherEdit)) + if (auto PathAndEdit = + Tweak::Effect::fileEdit(SM, SM.getFileID(*FwdLoc), OtherEdit)) MultiFileEffect->ApplyEdits.try_emplace(PathAndEdit->first, PathAndEdit->second); else
_______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits