njames93 updated this revision to Diff 318678.
njames93 edited the summary of this revision.
njames93 added a comment.

Fixed failing tests.
Updated the tweak's title when it is invoked on a specific base class.


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D94942/new/

https://reviews.llvm.org/D94942

Files:
  clang-tools-extra/clangd/refactor/tweaks/CMakeLists.txt
  clang-tools-extra/clangd/refactor/tweaks/ImplementAbstract.cpp
  clang-tools-extra/clangd/unittests/CMakeLists.txt
  clang-tools-extra/clangd/unittests/tweaks/ImplementAbstractTests.cpp

Index: clang-tools-extra/clangd/unittests/tweaks/ImplementAbstractTests.cpp
===================================================================
--- /dev/null
+++ clang-tools-extra/clangd/unittests/tweaks/ImplementAbstractTests.cpp
@@ -0,0 +1,393 @@
+//===-- ImplementAbstractTests.cpp ------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestTU.h"
+#include "TweakTesting.h"
+#include "gmock/gmock-matchers.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+using ::testing::Not;
+
+namespace clang {
+namespace clangd {
+namespace {
+
+// Provides the ImplementAbstractTest fixture used by the TEST_Fs below
+// (presumably defined by the TWEAK_TEST macro from TweakTesting.h).
+TWEAK_TEST(ImplementAbstract);
+
+// The tweak must not be offered when the selected class (marked with ^) has no
+// pure virtual methods left to implement.
+TEST_F(ImplementAbstractTest, TestUnavailable) {
+
+  StringRef Cases[]{
+      // Not a pure virtual method.
+      R"cpp(
+      class A {
+        virtual void Foo();
+      };
+      class ^B : public A {};
+    )cpp",
+      // Pure virtual method overridden in class.
+      R"cpp(
+      class A {
+        virtual void Foo() = 0;
+      };
+      class ^B : public A {
+        void Foo() override;
+      };
+    )cpp",
+      // Pure virtual method overridden in class with the virtual keyword.
+      R"cpp(
+      class A {
+        virtual void Foo() = 0;
+      };
+      class ^B : public A {
+        virtual void Foo() override;
+      };
+    )cpp",
+      // Pure virtual method overridden in class without the override keyword.
+      R"cpp(
+      class A {
+        virtual void Foo() = 0;
+      };
+      class ^B : public A {
+        void Foo();
+      };
+    )cpp",
+      // Pure virtual method overridden in an intermediate base class.
+      R"cpp(
+      class A {
+        virtual void Foo() = 0;
+      };
+      class B : public A {
+        void Foo() override;
+      };
+      class ^C : public B {
+      };
+    )cpp"};
+  for (const auto &Case : Cases) {
+    EXPECT_THAT(Case, Not(isAvailable()));
+  }
+}
+
+// For each case, TestHeader is assigned to the fixture's Header, the tweak is
+// applied at the ^ in TestSource, and the result must match ExpectedSource
+// exactly (including whitespace).
+TEST_F(ImplementAbstractTest, NormalAvailable) {
+  struct Case {
+    llvm::StringRef TestHeader;
+    llvm::StringRef TestSource;
+    llvm::StringRef ExpectedSource;
+  };
+
+  Case Cases[]{
+      // Private pure method into a class: inserted after '{', no label.
+      {
+          R"cpp(
+            class A {
+              virtual void Foo() = 0;
+            };)cpp",
+          R"cpp(
+            class B : public A {^};
+          )cpp",
+          R"cpp(
+            class B : public A {
+void Foo() override;
+};
+          )cpp",
+      },
+      // Public pure method into a class: needs a 'public:' label.
+      {
+          R"cpp(
+            class A {
+              public:
+              virtual void Foo() = 0;
+            };)cpp",
+          R"cpp(
+            class ^B : public A {};
+          )cpp",
+          R"cpp(
+            class B : public A {
+public:
+
+void Foo() override;
+};
+          )cpp",
+      },
+      // Parameters are reproduced in the generated declaration.
+      {
+          R"cpp(
+            class A {
+              virtual void Foo(int Param) = 0;
+            };)cpp",
+          R"cpp(
+            class ^B : public A {};
+          )cpp",
+          R"cpp(
+            class B : public A {
+void Foo(int Param) override;
+};
+          )cpp",
+      },
+      // Private pure method into a struct: needs a 'private:' label.
+      {
+          R"cpp(
+            class A {
+              virtual void Foo(int Param) = 0;
+            };)cpp",
+          R"cpp(
+            struct ^B : public A {};
+          )cpp",
+          R"cpp(
+            struct B : public A {
+private:
+
+void Foo(int Param) override;
+};
+          )cpp",
+      },
+      // Already-overridden Foo is skipped; only Bar is added, under a label.
+      {
+          R"cpp(
+            class A {
+              virtual void Foo(int Param) const volatile = 0;
+              public:
+              virtual void Bar(int Param) = 0;
+            };)cpp",
+          R"cpp(
+            class ^B : public A {
+              void Foo(int Param) const volatile override;
+            };
+          )cpp",
+          R"cpp(
+            class B : public A {
+              void Foo(int Param) const volatile override;
+            
+public:
+
+void Bar(int Param) override;
+};
+          )cpp",
+      },
+      // Foo is overridden by intermediate base B; Bar goes after the last
+      // existing private method.
+      {
+          R"cpp(
+                 class A {
+              virtual void Foo() = 0;
+              virtual void Bar() = 0;
+            };
+            class B : public A {
+              void Foo() override;
+            };
+              )cpp",
+          R"cpp(
+                class ^C : public B {
+                  virtual void Baz();
+                };
+              )cpp",
+          R"cpp(
+                class C : public B {
+                  virtual void Baz();
+void Bar() override;
+
+                };
+              )cpp",
+      },
+      // The destructor is not used as an anchor, so Foo goes right after '{'.
+      {
+          R"cpp(
+            class A {
+              virtual void Foo() = 0;
+            };)cpp",
+          R"cpp(
+            class ^B : public A {
+              ~B();
+            };
+          )cpp",
+          R"cpp(
+            class B : public A {
+void Foo() override;
+
+              ~B();
+            };
+          )cpp",
+      },
+      // Mixed access: private Foo after '{', public Bar under a new label.
+      {
+          R"cpp(
+            class A {
+              virtual void Foo() = 0;
+              public:
+              virtual void Bar() = 0;
+            };)cpp",
+          R"cpp(
+            class ^B : public A {
+            };
+          )cpp",
+          R"cpp(
+            class B : public A {
+void Foo() override;
+
+            
+public:
+
+void Bar() override;
+};
+          )cpp",
+      },
+      // Methods from both levels of the hierarchy, keeping their access.
+      {
+          R"cpp(
+            class A {
+              virtual void Foo() = 0;
+            };
+            struct B : public A {
+              virtual void Bar() = 0;
+            };)cpp",
+          R"cpp(
+            class ^C : public B {
+            };
+          )cpp",
+          R"cpp(
+            class C : public B {
+void Foo() override;
+
+            
+public:
+
+void Bar() override;
+};
+          )cpp",
+      },
+      // Private inheritance makes every inherited method private.
+      {
+          R"cpp(
+                  class A {
+                    virtual void Foo() = 0;
+                  };
+                  struct B : public A {
+                    virtual void Bar() = 0;
+                  };)cpp",
+          R"cpp(
+                  class ^C : private B {
+                  };
+                )cpp",
+          R"cpp(
+                  class C : private B {
+void Foo() override;
+void Bar() override;
+
+                  };
+                )cpp",
+      },
+      // Cursor on base specifier A: only A's methods are implemented.
+      {
+          R"cpp(
+            struct A {
+              virtual void Foo() = 0;
+            };
+            struct B {
+              virtual void Bar() = 0;
+            };)cpp",
+          R"cpp(
+            class C : public ^A, B {
+            };
+          )cpp",
+          R"cpp(
+            class C : public A, B {
+            
+public:
+
+void Foo() override;
+};
+          )cpp",
+      },
+      // Cursor on the class: all bases. B is inherited privately (class
+      // default), so Bar is private while Foo stays public.
+      {
+          R"cpp(
+            struct A {
+              virtual void Foo() = 0;
+            };
+            struct B {
+              virtual void Bar() = 0;
+            };)cpp",
+          R"cpp(
+            class ^C : public A, B {
+            };
+          )cpp",
+          R"cpp(
+            class C : public A, B {
+void Bar() override;
+
+            
+public:
+
+void Foo() override;
+};
+          )cpp",
+      },
+  };
+
+  for (const auto &Case : Cases) {
+    Header = Case.TestHeader.str();
+    EXPECT_EQ(apply(Case.TestSource), Case.ExpectedSource);
+  }
+}
+
+// Classes whose bases depend on a template parameter can't be analyzed, so the
+// tweak must not be offered.
+TEST_F(ImplementAbstractTest, TemplateUnavailable) {
+  StringRef Cases[]{
+      // Base is a dependent specialization A<T>.
+      R"cpp(
+        template<typename T>
+        class A {
+          virtual void Foo() = 0;
+        };
+        template<typename T>
+        class ^B : public A<T> {};
+        )cpp",
+      // Base is the template parameter itself.
+      R"cpp(
+        template<typename T>
+        class ^B : public T {};
+        )cpp",
+  };
+  for (const auto &Case : Cases) {
+    EXPECT_THAT(Case, Not(isAvailable()));
+  }
+}
+
+// Templates are supported as long as no base is dependent: a concrete
+// specialization as base, or a class template with a non-dependent base.
+TEST_F(ImplementAbstractTest, TemplateAvailable) {
+  struct Case {
+    llvm::StringRef TestHeader;
+    llvm::StringRef TestSource;
+    llvm::StringRef ExpectedSource;
+  };
+  Case Cases[]{
+      // Non-template class deriving from a concrete specialization.
+      {
+          R"cpp(
+            template<typename T>
+            class A {
+              virtual void Foo() = 0;
+            };
+            )cpp",
+          R"cpp(
+            class ^B : public A<int> {};
+            )cpp",
+          R"cpp(
+            class B : public A<int> {
+void Foo() override;
+};
+            )cpp",
+      },
+      // Class template with a non-dependent base.
+      {
+          R"cpp(
+            class A {
+              virtual void Foo() = 0;
+            };)cpp",
+          R"cpp(
+            template<typename T>
+            class ^B : public A {};
+            )cpp",
+          R"cpp(
+            template<typename T>
+            class B : public A {
+void Foo() override;
+};
+            )cpp",
+      },
+  };
+  for (const auto &Case : Cases) {
+    Header = Case.TestHeader.str();
+    EXPECT_EQ(apply(Case.TestSource), Case.ExpectedSource);
+  }
+}
+
+} // namespace
+} // namespace clangd
+} // namespace clang
Index: clang-tools-extra/clangd/unittests/CMakeLists.txt
===================================================================
--- clang-tools-extra/clangd/unittests/CMakeLists.txt
+++ clang-tools-extra/clangd/unittests/CMakeLists.txt
@@ -118,6 +118,7 @@
   tweaks/ExpandMacroTests.cpp
   tweaks/ExtractFunctionTests.cpp
   tweaks/ExtractVariableTests.cpp
+  tweaks/ImplementAbstractTests.cpp
   tweaks/ObjCLocalizeStringLiteralTests.cpp
   tweaks/PopulateSwitchTests.cpp
   tweaks/RawStringLiteralTests.cpp
Index: clang-tools-extra/clangd/refactor/tweaks/ImplementAbstract.cpp
===================================================================
--- /dev/null
+++ clang-tools-extra/clangd/refactor/tweaks/ImplementAbstract.cpp
@@ -0,0 +1,385 @@
+//===--- ImplementAbstract.cpp -----------------------------------*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "refactor/Tweak.h"
+#include "support/Logger.h"
+#include "clang/Basic/Specifiers.h"
+#include "llvm/ADT/PointerIntPair.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallVector.h"
+
+namespace clang {
+namespace clangd {
+
+namespace {
+
+/// A pure virtual method paired with the access it should be given when it is
+/// declared in the selected class. Two bits are enough to store any
+/// AccessSpecifier value (AS_public, AS_protected, AS_private or AS_none).
+using MethodAndAccess =
+    llvm::PointerIntPair<const CXXMethodDecl *, 2, AccessSpecifier>;
+
+/// Returns the more restrictive of an inherited access (\p InheritSpecifier)
+/// and the access a member was declared with (\p DefinedAs).
+AccessSpecifier getMostConstrained(AccessSpecifier InheritSpecifier,
+                                   AccessSpecifier DefinedAs) {
+  // std::max is only correct because the enumerators are declared in order of
+  // increasing restriction; turn that silent assumption into a compile check.
+  static_assert(AS_public < AS_protected && AS_protected < AS_private,
+                "AccessSpecifier must order public < protected < private");
+  return std::max(InheritSpecifier, DefinedAs);
+}
+
+/// For every pure virtual method declared directly in \p Record that is not in
+/// \p Overrides, appends an entry to \p Results. Each entry carries the more
+/// restrictive of \p Access and the method's own declared access.
+void collectNonOverriddenPureMethods(
+    const CXXRecordDecl &Record,
+    llvm::SmallVectorImpl<MethodAndAccess> &Results, AccessSpecifier Access,
+    const llvm::SmallPtrSetImpl<const CXXMethodDecl *> &Overrides) {
+  for (const CXXMethodDecl *Candidate : Record.methods()) {
+    // Keep only pure methods that have not been recorded as overridden.
+    if (Candidate->isPure() && !Overrides.contains(Candidate)) {
+      AccessSpecifier Effective =
+          getMostConstrained(Access, Candidate->getAccess());
+      Results.emplace_back(Candidate, Effective);
+    }
+  }
+}
+
+/// Adds to \p Overrides every method that a virtual method declared in
+/// \p Record directly overrides. When \p IsRoot is true, stops at the first
+/// pure method found in \p Record and returns true (\p Overrides may then be
+/// incomplete); otherwise returns false.
+bool buildOverrideSet(const CXXRecordDecl &Record,
+                      llvm::SmallPtrSetImpl<const CXXMethodDecl *> &Overrides,
+                      bool IsRoot) {
+  for (const CXXMethodDecl *Method : Record.methods()) {
+    if (Method->isVirtual()) {
+      if (IsRoot && Method->isPure())
+        return true;
+      Overrides.insert(Method->begin_overridden_methods(),
+                       Method->end_overridden_methods());
+    }
+  }
+  return false;
+}
+
+/// Recursively walks \p Record and its bases, appending to \p Results every
+/// pure virtual method that is not in \p Overrides. \p Access is the access
+/// accumulated along the inheritance path to \p Record. Returns true if some
+/// base is not a record type (most likely a dependent base), in which case
+/// \p Results is incomplete; returns false otherwise.
+bool collectPureMethodsImpl(
+    const CXXRecordDecl &Record,
+    llvm::SmallVectorImpl<MethodAndAccess> &Results, AccessSpecifier Access,
+    llvm::SmallPtrSetImpl<const CXXMethodDecl *> &Overrides) {
+  if (Record.getNumBases() != 0) {
+    // Record what this class overrides before descending, so those base
+    // methods are not reported.
+    buildOverrideSet(Record, Overrides, /*IsRoot=*/false);
+    for (const CXXBaseSpecifier &Base : Record.bases()) {
+      const auto *BaseType = Base.getType()->getAs<RecordType>();
+      if (!BaseType)
+        return true; // Probably a dependent base; give up.
+      const auto *BaseRecord = cast<CXXRecordDecl>(BaseType->getDecl());
+      // A non-polymorphic base has no virtual methods to contribute.
+      if (!BaseRecord->isPolymorphic())
+        continue;
+      AccessSpecifier BaseAccess =
+          getMostConstrained(Access, Base.getAccessSpecifier());
+      // Propagate a failure from further up the hierarchy.
+      if (collectPureMethodsImpl(*BaseRecord, Results, BaseAccess, Overrides))
+        return true;
+    }
+  }
+  // This class's own pure methods are added after its bases', so methods from
+  // classes higher up the hierarchy come first in the generated code.
+  collectNonOverriddenPureMethods(Record, Results, Access, Overrides);
+  return false;
+}
+
+/// Collects into \p Results the pure virtual methods reachable through the base
+/// specifier \p Base of \p RD that \p RD does not override. Returns true (the
+/// tweak should not be offered) in any of these cases:
+///  - the base is not a polymorphic record;
+///  - \p RD declares pure methods of its own;
+///  - a dependent base is encountered.
+/// Returns false otherwise.
+bool collectPureMethodsFromBase(
+    const CXXRecordDecl &RD, const CXXBaseSpecifier &Base,
+    llvm::SmallVectorImpl<MethodAndAccess> &Results) {
+  // CXXBaseSpecifier has no operator==, and the DynTypedNode that Base came
+  // from holds a copy, so pointer identity can't be used. Comparing the type
+  // source info and source range checks Base really is one of RD's bases.
+  assert(llvm::any_of(RD.bases(), [&Base](const CXXBaseSpecifier &Other) {
+    return Base.getTypeSourceInfo() == Other.getTypeSourceInfo() &&
+           Base.getSourceRange() == Other.getSourceRange();
+  }));
+  const auto *BaseType = Base.getType()->getAs<RecordType>();
+  if (!BaseType)
+    return true; // Probably a dependent base; give up.
+  const auto *BaseRecord = cast<CXXRecordDecl>(BaseType->getDecl());
+  if (!BaseRecord->isPolymorphic())
+    return true;
+  llvm::SmallPtrSet<const CXXMethodDecl *, 16> Overridden;
+  // Short-circuits: nothing is collected if RD declares pure methods itself.
+  return buildOverrideSet(RD, Overridden, /*IsRoot=*/true) ||
+         collectPureMethodsImpl(*BaseRecord, Results, Base.getAccessSpecifier(),
+                                Overridden);
+}
+
+/// Collects into \p Results every pure virtual method inherited by \p RD that
+/// \p RD does not override. Returns true (the tweak should not be offered) if
+/// \p RD declares pure virtual methods of its own, or if a dependent base is
+/// encountered; returns false otherwise.
+bool collectAllPureMethods(const CXXRecordDecl &RD,
+                           llvm::SmallVectorImpl<MethodAndAccess> &Results) {
+  llvm::SmallPtrSet<const CXXMethodDecl *, 16> Overrides;
+  // Same rule as collectPureMethodsFromBase. Without this check,
+  // collectPureMethodsImpl would report RD's own pure methods as
+  // unimplemented, and the tweak would insert duplicate declarations of them
+  // into RD.
+  if (buildOverrideSet(RD, Overrides, /*IsRoot=*/true))
+    return true;
+  return collectPureMethodsImpl(RD, Results, AS_public, Overrides);
+}
+
+/// Returns the class under the selection in \p Inputs, or nullptr if there is
+/// none. If the selection is on an entry of the class's base-specifier-list and
+/// \p BaseSpec is non-null, a copy of that base specifier is stored there.
+const CXXRecordDecl *getSelectedRecord(const Tweak::Selection &Inputs,
+                                       Optional<CXXBaseSpecifier> *BaseSpec) {
+  const SelectionTree::Node *Node = Inputs.ASTSelection.commonAncestor();
+  if (!Node)
+    return nullptr;
+  // Selection directly on the class itself.
+  if (const auto *Record = Node->ASTNode.get<CXXRecordDecl>())
+    return Record;
+  // Otherwise it may be on one of the class's base specifiers.
+  const auto *Spec = Node->ASTNode.get<CXXBaseSpecifier>();
+  if (!Spec || !Node->Parent)
+    return nullptr;
+  const auto *Owner = Node->Parent->ASTNode.get<CXXRecordDecl>();
+  if (Owner && BaseSpec)
+    *BaseSpec = *Spec;
+  return Owner;
+}
+
+/// Cheap checks run before collecting virtual methods. The declaration must:
+///  - be the definition of a class or struct (e.g. not a union);
+///  - have at least one base, and no dependent bases;
+///  - be polymorphic.
+bool isClassOK(const CXXRecordDecl &RecordDecl) {
+  const bool IsCandidate = RecordDecl.isThisDeclarationADefinition() &&
+                           (RecordDecl.isClass() || RecordDecl.isStruct()) &&
+                           !RecordDecl.hasAnyDependentBases() &&
+                           RecordDecl.getNumBases() != 0;
+  // Requiring the class to be abstract would be more precise, but it would
+  // stop the tweak working on templates that have no dependent bases.
+  return IsCandidate && RecordDecl.isPolymorphic();
+}
+
+/// A place in the class body where declarations with a given access can be
+/// inserted without writing an access specifier.
+struct InsertionDetail {
+  /// Where to insert; invalid until a suitable location has been found.
+  SourceLocation Loc = {};
+  /// The access that declarations inserted at Loc will have. Defaults to
+  /// AS_none instead of being left uninitialized.
+  AccessSpecifier Access = AS_none;
+  /// What Loc follows: 0 = other members (e.g. fields), 1 = an access
+  /// specifier, 2 = a method. A higher value is preferred.
+  unsigned char AfterPriority = 0;
+};
+
+/// Returns the location just past the end of \p D, including its terminating
+/// semicolon when there is one. This is needed because the end location of a
+/// declaration does not cover the semicolon.
+SourceLocation getLocAfterDecl(const Decl &D, const SourceManager &SM,
+                               const LangOptions &LO) {
+  if (!D.hasBody()) {
+    if (auto Next = Lexer::findNextToken(D.getEndLoc(), SM, LO)) {
+      if (Next->is(tok::semi))
+        return Next->getEndLoc();
+    }
+  }
+  // Step over the whole last token, not just one character, so this is also
+  // correct when that token is longer than a single character.
+  SourceLocation AfterLastToken =
+      Lexer::getLocForEndOfToken(D.getEndLoc(), /*Offset=*/0, SM, LO);
+  if (AfterLastToken.isValid())
+    return AfterLastToken;
+  // getLocForEndOfToken can fail (e.g. for some macro locations); keep the
+  // previous one-character behaviour there.
+  return D.getEndLoc().getLocWithOffset(1);
+}
+
+/// Generate insertion points in \p R that don't require inserting access
+/// specifiers. Each insertion point generally sits after the last method
+/// declared in the class with a given access. \p ShouldIncludeAccess is
+/// indexed by AccessSpecifier and avoids generating insertion points for
+/// accesses we aren't going to fill in.
+SmallVector<InsertionDetail, 3>
+getInsertionPoints(const CXXRecordDecl &R, ArrayRef<bool> ShouldIncludeAccess,
+                   const SourceManager &SM, const LangOptions &LO) {
+  SmallVector<InsertionDetail, 3> Result;
+  auto GetDetailForAccess = [&](AccessSpecifier Spec) -> InsertionDetail & {
+    assert(Spec != AS_none);
+    for (InsertionDetail &Item : Result) {
+      if (Item.Access == Spec)
+        return Item;
+    }
+    return Result.emplace_back(InsertionDetail{{}, Spec});
+  };
+
+  // For each access, find an insertion point after the last method declared
+  // with it. This keeps the implemented methods' visibility without adding
+  // unnecessary access specifiers.
+  for (Decl *Member : R.decls()) {
+    const AccessSpecifier MemberAccess = Member->getAccess();
+    // Some members (e.g. static_assert declarations) have no access at all.
+    // AS_none must be rejected before indexing ShouldIncludeAccess, which the
+    // caller sizes for public/protected/private only.
+    if (MemberAccess == AS_none || !ShouldIncludeAccess[MemberAccess])
+      continue;
+    // Ignore things like compiler generated special member functions.
+    if (Member->isImplicit())
+      continue;
+    // Hack to try and leave the destructor as last method in a block.
+    if (isa<CXXDestructorDecl>(Member))
+      continue;
+    InsertionDetail &Detail = GetDetailForAccess(MemberAccess);
+    if (isa<CXXMethodDecl>(Member)) {
+      Detail.Loc = getLocAfterDecl(*Member, SM, LO);
+      Detail.AfterPriority = 2;
+    } else {
+      // Try to put methods after access spec but before fields.
+      const unsigned char Priority = isa<AccessSpecDecl>(Member) ? 1 : 0;
+      if (Detail.AfterPriority <= Priority) {
+        Detail.Loc = getLocAfterDecl(*Member, SM, LO);
+        Detail.AfterPriority = Priority;
+      }
+    }
+  }
+  if (Result.empty()) {
+    auto Access = R.isClass() ? AS_private : AS_public;
+    if (ShouldIncludeAccess[Access]) {
+      // An empty class so start inserting methods that don't need an access
+      // specifier just after the open curly brace.
+      GetDetailForAccess(Access).Loc =
+          R.getBraceRange().getBegin().getLocWithOffset(1);
+    }
+  }
+  return Result;
+}
+
+/// Printing callbacks that report a scope as visible when it encloses
+/// \p CurContext. This presumably lets the printer omit qualifiers that are
+/// redundant at that point in the code.
+class PrintingInContextCallback : public PrintingCallbacks {
+public:
+  explicit PrintingInContextCallback(const DeclContext *CurContext)
+      : CurContext(CurContext) {}
+
+  bool isScopeVisible(const DeclContext *DC) const override {
+    return DC->Encloses(CurContext);
+  }
+
+private:
+  const DeclContext *CurContext;
+};
+
+/// Prints an overriding declaration for every method in \p Items whose access
+/// is \p AccessKind. Names are printed relative to \p PrintContext. When
+/// \p PrintAccessSpec is true, the declarations are preceded by an access
+/// label.
+void printMethods(llvm::raw_ostream &Out, ArrayRef<MethodAndAccess> Items,
+                  AccessSpecifier AccessKind, const CXXRecordDecl *PrintContext,
+                  bool PrintAccessSpec) {
+  PrintingInContextCallback Callbacks(PrintContext);
+  auto Policy = PrintContext->getASTContext().getPrintingPolicy();
+  Policy.SuppressScope = false;
+  Policy.Callbacks = &Callbacks;
+  if (PrintAccessSpec)
+    Out << "\n" << getAccessSpelling(AccessKind) << ":\n";
+  Out << "\n";
+  for (const MethodAndAccess &Entry : Items) {
+    if (Entry.getInt() != AccessKind)
+      continue;
+    const CXXMethodDecl *Method = Entry.getPointer();
+    Method->getReturnType().print(Out, Policy);
+    Out << ' ' << Method->getNameAsString() << "(";
+    bool IsFirst = true;
+    for (const auto &Param : Method->parameters()) {
+      if (!IsFirst)
+        Out << ", ";
+      IsFirst = false;
+      Param->print(Out, Policy);
+    }
+    // A C-style variadic method must stay variadic to be an override.
+    if (Method->isVariadic())
+      Out << (IsFirst ? "..." : ", ...");
+    Out << ") ";
+    if (Method->isConst())
+      Out << "const ";
+    if (Method->isVolatile())
+      Out << "volatile ";
+    // The ref-qualifier is part of the signature; without it the declaration
+    // would not override the base method.
+    switch (Method->getRefQualifier()) {
+    case RQ_None:
+      break;
+    case RQ_LValue:
+      Out << "& ";
+      break;
+    case RQ_RValue:
+      Out << "&& ";
+      break;
+    }
+    // An overrider may not have a laxer exception specification than the
+    // method it overrides, so keep non-throwing methods non-throwing.
+    const auto *Proto = Method->getType()->getAs<FunctionProtoType>();
+    if (Proto && Proto->isNothrow())
+      Out << "noexcept ";
+    // Always suggest `override` over `final`.
+    Out << "override;\n";
+  }
+}
+
+/// Tweak that declares overrides in the selected class for the pure virtual
+/// methods it inherits but does not yet override. When triggered on an entry
+/// of the base-specifier-list, only that base's methods are considered.
+class ImplementAbstract : public Tweak {
+public:
+  const char *id() const override;
+
+  bool prepare(const Selection &Inputs) override {
+    Selected = getSelectedRecord(Inputs, &FromBase);
+    if (!Selected || !isClassOK(*Selected))
+      return false;
+    // The collectors return true when the tweak must not be offered.
+    if (FromBase) {
+      if (collectPureMethodsFromBase(*Selected, *FromBase, PureVirtualMethods))
+        return false;
+    } else if (collectAllPureMethods(*Selected, PureVirtualMethods)) {
+      return false;
+    }
+    return !PureVirtualMethods.empty();
+  }
+
+  Expected<Effect> apply(const Selection &Inputs) override {
+    // prepare() only succeeds with at least one pure virtual method to add.
+    assert(!PureVirtualMethods.empty() &&
+           "Prepare returned true when no methods existed");
+    // Indexed by AccessSpecifier: which accesses still have methods to insert.
+    bool AccessNeedsProcessing[3] = {false, false, false};
+    for (const MethodAndAccess &Item : PureVirtualMethods)
+      AccessNeedsProcessing[Item.getInt()] = true;
+
+    const SourceManager &SM = Inputs.AST->getSourceManager();
+    auto InsertionPoints = getInsertionPoints(*Selected, AccessNeedsProcessing,
+                                              SM, Inputs.AST->getLangOpts());
+    SmallString<256> Buffer;
+    llvm::raw_svector_ostream OS(Buffer);
+    tooling::Replacements Replacements;
+    // First, insert each access's methods where no access specifier is needed.
+    for (const InsertionDetail &Item : InsertionPoints) {
+      assert(Item.Loc.isValid());
+      if (!AccessNeedsProcessing[Item.Access])
+        continue;
+      AccessNeedsProcessing[Item.Access] = false;
+      printMethods(OS, PureVirtualMethods, Item.Access, Selected,
+                   /*PrintAccessSpec=*/false);
+      if (auto Err = Replacements.add(
+              tooling::Replacement(SM, Item.Loc, 0, Buffer)))
+        return std::move(Err);
+      Buffer.clear();
+    }
+
+    // Any accesses not covered above are added, each under its own access
+    // specifier, in a single insertion before the closing brace.
+    for (AccessSpecifier Spec : {AS_public, AS_protected, AS_private}) {
+      if (AccessNeedsProcessing[Spec])
+        printMethods(OS, PureVirtualMethods, Spec, Selected,
+                     /*PrintAccessSpec=*/true);
+    }
+    if (!Buffer.empty()) {
+      if (auto Err = Replacements.add(tooling::Replacement(
+              SM, Selected->getBraceRange().getEnd(), 0, Buffer)))
+        return std::move(Err);
+    }
+    return Effect::mainFileEdit(SM, std::move(Replacements));
+  }
+
+  std::string title() const override {
+    if (FromBase) {
+      assert(Selected);
+      PrintingInContextCallback Callbacks(Selected->getDeclContext());
+      auto Policy = Selected->getParentASTContext().getPrintingPolicy();
+      Policy.SuppressScope = false;
+      Policy.Callbacks = &Callbacks;
+      std::string Result = "Implement pure virtual methods from '";
+      llvm::raw_string_ostream OS(Result);
+      FromBase->getTypeSourceInfo()->getType().print(OS, Policy);
+      OS << '\'';
+      OS.flush();
+      return Result;
+    }
+    return "Implement pure virtual methods";
+  }
+
+  llvm::StringLiteral kind() const override {
+    return CodeAction::REFACTOR_KIND;
+  }
+
+private:
+  /// The class being edited; set by prepare().
+  const CXXRecordDecl *Selected = nullptr;
+  /// Pure virtual methods to implement, each with the access it should get.
+  llvm::SmallVector<MethodAndAccess, 0> PureVirtualMethods;
+  /// Set when triggered on a base specifier; only that base is considered.
+  llvm::Optional<CXXBaseSpecifier> FromBase;
+};
+
+// Make ImplementAbstract available as a clangd tweak.
+REGISTER_TWEAK(ImplementAbstract)
+
+} // namespace
+} // namespace clangd
+} // namespace clang
Index: clang-tools-extra/clangd/refactor/tweaks/CMakeLists.txt
===================================================================
--- clang-tools-extra/clangd/refactor/tweaks/CMakeLists.txt
+++ clang-tools-extra/clangd/refactor/tweaks/CMakeLists.txt
@@ -21,6 +21,7 @@
   ExpandMacro.cpp
   ExtractFunction.cpp
   ExtractVariable.cpp
+  ImplementAbstract.cpp
   ObjCLocalizeStringLiteral.cpp
   PopulateSwitch.cpp
   RawStringLiteral.cpp
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to