njames93 updated this revision to Diff 317984.
njames93 added a comment.

- Replace getFinalOverriders with a manual implementation; that method wasn't
well suited to tracking the access of each method.
- Add support for template classes with no dependent bases.
- Add tests for template classes.


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D94942/new/

https://reviews.llvm.org/D94942

Files:
  clang-tools-extra/clangd/refactor/tweaks/CMakeLists.txt
  clang-tools-extra/clangd/refactor/tweaks/ImplementAbstract.cpp
  clang-tools-extra/clangd/unittests/CMakeLists.txt
  clang-tools-extra/clangd/unittests/tweaks/ImplementAbstractTests.cpp

Index: clang-tools-extra/clangd/unittests/tweaks/ImplementAbstractTests.cpp
===================================================================
--- /dev/null
+++ clang-tools-extra/clangd/unittests/tweaks/ImplementAbstractTests.cpp
@@ -0,0 +1,349 @@
+//===-- ImplementAbstractTests.cpp ------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestTU.h"
+#include "TweakTesting.h"
+#include "gmock/gmock-matchers.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+using ::testing::Not;
+
+namespace clang {
+namespace clangd {
+namespace {
+
+TWEAK_TEST(ImplementAbstract); // Defines the ImplementAbstractTest fixture.
+
+TEST_F(ImplementAbstractTest, TestUnavailable) {
+  // The tweak must not be offered at the ^ in any of these inputs.
+  StringRef Cases[]{
+      // Not a pure virtual method.
+      R"cpp(
+      class A {
+        virtual void Foo();
+      };
+      class ^B : public A {};
+    )cpp",
+      // Pure virtual method overridden in class.
+      R"cpp(
+      class A {
+        virtual void Foo() = 0;
+      };
+      class ^B : public A {
+        void Foo() override;
+      };
+    )cpp",
+      // Pure virtual method overridden in class with virtual keyword.
+      R"cpp(
+      class A {
+        virtual void Foo() = 0;
+      };
+      class ^B : public A {
+        virtual void Foo() override;
+      };
+    )cpp",
+      // Pure virtual method overridden in class without override keyword.
+      R"cpp(
+      class A {
+        virtual void Foo() = 0;
+      };
+      class ^B : public A {
+        void Foo();
+      };
+    )cpp",
+      // Pure virtual method overridden in an intermediate base class.
+      R"cpp(
+      class A {
+        virtual void Foo() = 0;
+      };
+      class B : public A {
+        void Foo() override;
+      };
+      class ^C : public B {
+      };
+    )cpp"};
+  for (const auto &Case : Cases) {
+    EXPECT_THAT(Case, Not(isAvailable()));
+  }
+}
+
+TEST_F(ImplementAbstractTest, NormalAvailable) {
+  struct Case {
+    llvm::StringRef TestHeader;     // Bases; assigned to the fixture's Header.
+    llvm::StringRef TestSource;     // Main file; ^ marks the selection.
+    llvm::StringRef ExpectedSource; // Main file after applying the tweak.
+  };
+
+  Case Cases[]{
+      { // Private method: inserted after '{', no access label needed.
+          R"cpp(
+      class A {
+        virtual void Foo() = 0;
+      };)cpp",
+          R"cpp(
+      class B : public A {^};
+    )cpp",
+          R"cpp(
+      class B : public A {
+void Foo() override;
+};
+    )cpp",
+      },
+      { // Public method: a public: section is added before the '}'.
+          R"cpp(
+      class A {
+        public:
+        virtual void Foo() = 0;
+      };)cpp",
+          R"cpp(
+      class ^B : public A {};
+    )cpp",
+          R"cpp(
+      class B : public A {
+public:
+
+void Foo() override;
+};
+    )cpp",
+      },
+      { // Parameters are printed with their names.
+          R"cpp(
+      class A {
+        virtual void Foo(int Param) = 0;
+      };)cpp",
+          R"cpp(
+      class ^B : public A {};
+    )cpp",
+          R"cpp(
+      class B : public A {
+void Foo(int Param) override;
+};
+    )cpp",
+      },
+      { // Private method in a struct: a private: section is added.
+          R"cpp(
+      class A {
+        virtual void Foo(int Param) = 0;
+      };)cpp",
+          R"cpp(
+      struct ^B : public A {};
+    )cpp",
+          R"cpp(
+      struct B : public A {
+private:
+
+void Foo(int Param) override;
+};
+    )cpp",
+      },
+      { // cv-qualified Foo is already overridden; only Bar is added.
+          R"cpp(
+      class A {
+        virtual void Foo(int Param) const volatile = 0;
+        public:
+        virtual void Bar(int Param) = 0;
+      };)cpp",
+          R"cpp(
+      class ^B : public A {
+        void Foo(int Param) const volatile override;
+      };
+    )cpp",
+          R"cpp(
+      class B : public A {
+        void Foo(int Param) const volatile override;
+      
+public:
+
+void Bar(int Param) override;
+};
+    )cpp",
+      },
+      { // Foo is overridden by the intermediate base B; Bar follows Baz.
+          R"cpp(
+           class A {
+        virtual void Foo() = 0;
+        virtual void Bar() = 0;
+      };
+      class B : public A {
+        void Foo() override;
+      };
+        )cpp",
+          R"cpp(
+          class ^C : public B {
+            virtual void Baz();
+          };
+        )cpp",
+          R"cpp(
+          class C : public B {
+            virtual void Baz();
+void Bar() override;
+
+          };
+        )cpp",
+      },
+      { // Destructors are skipped, so Foo is inserted after '{'.
+          R"cpp(
+      class A {
+        virtual void Foo() = 0;
+      };)cpp",
+          R"cpp(
+      class ^B : public A {
+        ~B();
+      };
+    )cpp",
+          R"cpp(
+      class B : public A {
+void Foo() override;
+
+        ~B();
+      };
+    )cpp",
+      },
+      { // Private Foo goes after '{', public Bar gets a new section.
+          R"cpp(
+      class A {
+        virtual void Foo() = 0;
+        public:
+        virtual void Bar() = 0;
+      };)cpp",
+          R"cpp(
+      class ^B : public A {
+      };
+    )cpp",
+          R"cpp(
+      class B : public A {
+void Foo() override;
+
+      
+public:
+
+void Bar() override;
+};
+    )cpp",
+      },
+      { // Public inheritance keeps each method's own access.
+          R"cpp(
+      class A {
+        virtual void Foo() = 0;
+      };
+      struct B : public A {
+        virtual void Bar() = 0;
+      };)cpp",
+          R"cpp(
+      class ^C : public B {
+      };
+    )cpp",
+          R"cpp(
+      class C : public B {
+void Foo() override;
+
+      
+public:
+
+void Bar() override;
+};
+    )cpp",
+      },
+      { // Private inheritance makes every method private.
+          R"cpp(
+            class A {
+              virtual void Foo() = 0;
+            };
+            struct B : public A {
+              virtual void Bar() = 0;
+            };)cpp",
+          R"cpp(
+            class ^C : private B {
+            };
+          )cpp",
+          R"cpp(
+            class C : private B {
+void Foo() override;
+void Bar() override;
+
+            };
+          )cpp",
+      },
+  };
+
+  for (const auto &Case : Cases) {
+    Header = Case.TestHeader.str();
+    EXPECT_EQ(apply(Case.TestSource), Case.ExpectedSource);
+  }
+}
+
+TEST_F(ImplementAbstractTest, TemplateUnavailable) {
+  StringRef Cases[]{ // Bases that depend on a template parameter.
+      R"cpp(
+        template<typename T>
+        class A {
+          virtual void Foo() = 0;
+        };
+        template<typename T>
+        class ^B : public A<T> {};
+        )cpp",
+      R"cpp(
+        template<typename T>
+        class ^B : public T {};
+        )cpp",
+  };
+  for (const auto &Case : Cases) {
+    EXPECT_THAT(Case, Not(isAvailable()));
+  }
+}
+
+TEST_F(ImplementAbstractTest, TemplateAvailable) {
+  struct Case {
+    llvm::StringRef TestHeader;     // Bases; assigned to the fixture's Header.
+    llvm::StringRef TestSource;     // Main file; ^ marks the selection.
+    llvm::StringRef ExpectedSource; // Main file after applying the tweak.
+  };
+  Case Cases[]{
+      { // Base is a specialization of a class template.
+          R"cpp(
+            template<typename T>
+            class A {
+              virtual void Foo() = 0;
+            };
+            )cpp",
+          R"cpp(
+            class ^B : public A<int> {};
+            )cpp",
+          R"cpp(
+            class B : public A<int> {
+void Foo() override;
+};
+            )cpp",
+      },
+      { // Class template whose base is not dependent.
+          R"cpp(
+            class A {
+              virtual void Foo() = 0;
+            };)cpp",
+          R"cpp(
+            template<typename T>
+            class ^B : public A {};
+            )cpp",
+          R"cpp(
+            template<typename T>
+            class B : public A {
+void Foo() override;
+};
+            )cpp",
+      },
+  };
+  for (const auto &Case : Cases) {
+    Header = Case.TestHeader.str();
+    EXPECT_EQ(apply(Case.TestSource), Case.ExpectedSource);
+  }
+}
+
+} // namespace
+} // namespace clangd
+} // namespace clang
Index: clang-tools-extra/clangd/unittests/CMakeLists.txt
===================================================================
--- clang-tools-extra/clangd/unittests/CMakeLists.txt
+++ clang-tools-extra/clangd/unittests/CMakeLists.txt
@@ -118,6 +118,7 @@
   tweaks/ExpandMacroTests.cpp
   tweaks/ExtractFunctionTests.cpp
   tweaks/ExtractVariableTests.cpp
+  tweaks/ImplementAbstractTests.cpp
   tweaks/ObjCLocalizeStringLiteralTests.cpp
   tweaks/PopulateSwitchTests.cpp
   tweaks/RawStringLiteralTests.cpp
Index: clang-tools-extra/clangd/refactor/tweaks/ImplementAbstract.cpp
===================================================================
--- /dev/null
+++ clang-tools-extra/clangd/refactor/tweaks/ImplementAbstract.cpp
@@ -0,0 +1,296 @@
+//===--- ImplementAbstract.cpp -----------------------------------*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "refactor/Tweak.h"
+#include "support/Logger.h"
+#include "llvm/ADT/PointerIntPair.h"
+#include "llvm/ADT/SmallPtrSet.h"
+
+namespace clang {
+namespace clangd {
+// A pure virtual method paired with the access its override should be given.
+using MethodAndAccess =
+    llvm::PointerIntPair<const CXXMethodDecl *, 2, AccessSpecifier>;
+namespace {
+class ImplementAbstract : public Tweak {
+public:
+  const char *id() const override;
+  bool prepare(const Selection &Inputs) override;
+  Expected<Effect> apply(const Selection &Inputs) override;
+  std::string title() const override {
+    return "Implement pure virtual methods";
+  }
+  llvm::StringLiteral kind() const override {
+    return CodeAction::REFACTOR_KIND;
+  }
+
+private:
+  const CXXRecordDecl *Selected = nullptr; // Class the tweak applies to.
+  const CXXRecordDecl *Recent = nullptr;   // Set by prepare(), never read.
+  // Pure virtuals left unimplemented, with the access each override needs.
+  std::vector<MethodAndAccess> PureVirtualMethods;
+};
+
+AccessSpecifier getConstrained(AccessSpecifier Inherited,
+                               AccessSpecifier Declared) {
+  return Inherited < Declared ? Declared : Inherited; // The stricter access.
+}
+
+// Appends to Results each pure virtual method reachable from Record that no
+// more-derived class overrides, with its effective access. Returns true to
+// bail out if the root declares pure methods or a non-record base is found.
+bool collectPureVirtual(const CXXRecordDecl &Record,
+                        std::vector<MethodAndAccess> &Results,
+                        AccessSpecifier Access,
+                        llvm::SmallPtrSetImpl<const CXXMethodDecl *> &Overrides,
+                        bool IsRoot) {
+  if (Record.getNumBases() == 0) {
+    // No bases, so nothing here overrides anything. Marking pure methods in
+    // Overrides reports each once even if reachable via several paths.
+    for (const CXXMethodDecl *Method : Record.methods()) {
+      if (!Method->isPure())
+        continue;
+      if (IsRoot)
+        return true;
+      // Skip methods already overridden by a derived class or already added.
+      if (Overrides.insert(Method).second)
+        Results.emplace_back(Method,
+                             getConstrained(Access, Method->getAccess()));
+    }
+    return false;
+  }
+  for (const CXXMethodDecl *Method : Record.methods()) {
+    if (!Method->isVirtual())
+      continue;
+    if (IsRoot && Method->isPure())
+      return true;
+    for (const auto *Overridden : Method->overridden_methods())
+      Overrides.insert(Overridden);
+  }
+  for (const CXXBaseSpecifier &Base : Record.bases()) {
+    const RecordType *RT = Base.getType()->getAs<RecordType>();
+    if (!RT)
+      // Probably a dependent base, just error out.
+      return true;
+    const CXXRecordDecl *BaseDecl = cast<CXXRecordDecl>(RT->getDecl());
+    if (!BaseDecl->isPolymorphic())
+      continue;
+    if (collectPureVirtual(*BaseDecl, Results,
+                           getConstrained(Access, Base.getAccessSpecifier()),
+                           Overrides, false))
+      // Propagate any error back up.
+      return true;
+  }
+  // Add the pure methods from this class after traversing the bases, so they
+  // appear after the bases' methods in the generated code.
+  for (const CXXMethodDecl *Method : Record.methods()) {
+    if (!Method->isPure())
+      continue;
+    // Skip methods already overridden by a derived class or already added.
+    if (Overrides.insert(Method).second)
+      Results.emplace_back(Method, getConstrained(Access, Method->getAccess()));
+  }
+  return false;
+}
+
+/// Returns the CXXRecordDecl the selection node refers to, or null if none.
+static const CXXRecordDecl *
+getSelectedRecord(const SelectionTree::Node *SelNode) {
+  if (SelNode == nullptr)
+    return nullptr;
+  return SelNode->ASTNode.get<CXXRecordDecl>();
+}
+
+bool ImplementAbstract::prepare(const Selection &Inputs) {
+  // FIXME: The selection doesn't resolve to the class when the caret is in the
+  // class body, so the tweak is only offered when touching the marked ranges.
+  // It would be nicer to offer it whenever the cursor is inside the class (but
+  // perhaps not inside the class's member declarations).
+  //  [[class]]  [[Derived]]  [[:]]  [[public]]  Base  [[{]]
+  //   ^
+  // [[}]];
+  Selected = getSelectedRecord(Inputs.ASTSelection.commonAncestor());
+  if (!Selected)
+    return false;
+
+  // Only offer the tweak on a class/struct definition with non-dependent bases.
+  if (!Selected->isThisDeclarationADefinition())
+    return false;
+  if (!Selected->isClass() && !Selected->isStruct())
+    return false;
+  if (Selected->hasAnyDependentBases() || Selected->getNumBases() == 0)
+    return false;
+  // isAbstract() would be the natural check, but it rules out templated classes
+  // that have no dependent bases, so require a polymorphic class instead.
+  if (!Selected->isPolymorphic())
+    return false;
+
+  Recent = Selected->getMostRecentDecl();
+  // Fails if Selected declares its own pure methods or hits a non-record base.
+  llvm::SmallPtrSet<const CXXMethodDecl *, 16> Overrides;
+  if (collectPureVirtual(*Selected, PureVirtualMethods, AS_public, Overrides,
+                         true))
+    return false;
+  return !PureVirtualMethods.empty();
+}
+
+/// Writes an `override` declaration for each method in Items to Out. If
+/// AccessSpec is non-empty, an `AccessSpec:` label is written first.
+static void printMethods(llvm::raw_ostream &Out,
+                         ArrayRef<const CXXMethodDecl *> Items,
+                         const CXXRecordDecl *PrintContext,
+                         StringRef AccessSpec = {}) {
+  // Leave names unqualified in scopes that already enclose PrintContext.
+  struct PrintCB : PrintingCallbacks {
+    explicit PrintCB(const DeclContext *CurContext) : CurContext(CurContext) {}
+    virtual ~PrintCB() = default;
+    bool isScopeVisible(const DeclContext *DC) const override {
+      return DC->Encloses(CurContext);
+    }
+    const DeclContext *CurContext;
+  };
+  PrintCB Callbacks(PrintContext);
+  auto Policy = PrintContext->getASTContext().getPrintingPolicy();
+  Policy.SuppressScope = false;
+  Policy.Callbacks = &Callbacks;
+  if (!AccessSpec.empty())
+    Out << "\n" << AccessSpec << ":\n";
+  Out << "\n";
+  for (const CXXMethodDecl *Method : Items) {
+    // A conversion operator's name already spells its type; no return type.
+    if (!isa<CXXConversionDecl>(Method))
+      Out << Method->getReturnType().getAsString(Policy) << ' ';
+    Out << Method->getNameAsString() << "(";
+    llvm::interleaveComma(Method->parameters(), Out,
+                          [&](const ParmVarDecl *P) { P->print(Out, Policy); });
+    Out << ") ";
+    if (Method->isConst())
+      Out << "const ";
+    if (Method->isVolatile())
+      Out << "volatile ";
+    // Omitting the ref-qualifier or noexcept would make `override` ill-formed.
+    if (Method->getRefQualifier() != RQ_None)
+      Out << (Method->getRefQualifier() == RQ_LValue ? "& " : "&& ");
+    if (Method->getType()->castAs<FunctionProtoType>()->isNothrow())
+      Out << "noexcept ";
+    // Always suggest `override` over `final`.
+    Out << "override;\n";
+  }
+}
+
+Expected<Tweak::Effect> ImplementAbstract::apply(const Selection &Inputs) {
+  llvm::SmallVector<const CXXMethodDecl *, 4> GroupedAccessMethods[3];
+
+  for (const MethodAndAccess &PVM : PureVirtualMethods)
+    GroupedAccessMethods[PVM.getInt()].push_back(PVM.getPointer());
+
+  // We should have at least one pure virtual method to add.
+  assert(llvm::any_of(
+      GroupedAccessMethods,
+      [](ArrayRef<const CXXMethodDecl *> Array) { return !Array.empty(); }));
+
+  struct InsertionDetail {
+    SourceLocation Loc = {};
+    bool RefersToMethod = false;
+  };
+
+  using DetailAndAccess = std::pair<InsertionDetail, AccessSpecifier>;
+  SmallVector<DetailAndAccess, 3> InsertionPoints;
+
+  auto GetDetailForAccess = [&](AccessSpecifier Spec) -> InsertionDetail & {
+    assert(Spec != AS_none);
+    for (DetailAndAccess &Item : InsertionPoints)
+      if (Item.second == Spec)
+        return Item.first;
+    return InsertionPoints.emplace_back(InsertionDetail{}, Spec).first;
+  };
+
+  // Returns the location just past the decl whose last token is at Loc and its
+  // ';', skipping tokens its range may omit (e.g. `override`). After '}' or ':'
+  // only an immediate ';' is taken, so we never run into the next decl.
+  auto Next = [&](SourceLocation Loc) {
+    const SourceManager &SM = Inputs.AST->getSourceManager();
+    const LangOptions &LO = Inputs.AST->getLangOpts();
+    const char Last = *SM.getCharacterData(Loc);
+    while (llvm::Optional<Token> Tok = Lexer::findNextToken(Loc, SM, LO)) {
+      if (Tok->is(tok::semi))
+        return Tok->getEndLoc();
+      if (Last == '}' || Last == ':' || Tok->isOneOf(tok::r_brace, tok::eof))
+        break;
+      Loc = Tok->getLocation();
+    }
+    return Lexer::getLocForEndOfToken(Loc, 0, SM, LO);
+  };
+  // Find an insertion point after the last method declared with each access
+  // specifier, so implemented methods keep the visibility they need without
+  // adding redundant access specifiers.
+  for (auto *Decl : Selected->decls()) {
+    // Skip compiler generated decls (e.g. special members) and decls with no
+    // access such as static_assert, which can't index GroupedAccessMethods.
+    if (Decl->isImplicit() || Decl->getAccess() == AS_none)
+      continue;
+    // Hack to try and leave the destructor as last method in a block.
+    if (isa<CXXDestructorDecl>(Decl))
+      continue;
+    InsertionDetail &Detail = GetDetailForAccess(Decl->getAccess());
+    if (isa<CXXMethodDecl>(Decl)) {
+      Detail.Loc = Next(Decl->getSourceRange().getEnd());
+      Detail.RefersToMethod = true;
+    } else if (!Detail.RefersToMethod) {
+      // Last decl with this access wasn't method decl.
+      Detail.Loc = Next(Decl->getSourceRange().getEnd());
+    }
+  }
+  if (InsertionPoints.empty()) {
+    // Nothing usable in the body: insert after '{' with the default access.
+    GetDetailForAccess(Selected->isClass() ? AS_private : AS_public) =
+        InsertionDetail{
+            Selected->getBraceRange().getBegin().getLocWithOffset(1), true};
+  }
+
+  SmallString<256> Buffer;
+  llvm::raw_svector_ostream OS(Buffer);
+  tooling::Replacements Replacements;
+  for (auto &Item : InsertionPoints) {
+    assert(Item.first.Loc.isValid());
+    llvm::SmallVectorImpl<const CXXMethodDecl *> &GroupedMethods =
+        GroupedAccessMethods[Item.second];
+    if (GroupedMethods.empty())
+      continue;
+    printMethods(OS, GroupedMethods, Selected);
+    if (auto Err = Replacements.add(tooling::Replacement(
+            Inputs.AST->getSourceManager(), Item.first.Loc, 0, Buffer)))
+      return std::move(Err);
+    // Don't print these methods again in the fallback loop below.
+    GroupedMethods.clear();
+    Buffer.clear();
+  }
+
+  // Any access specifiers not covered can be added in one insertion.
+  for (AccessSpecifier Spec : {AS_public, AS_protected, AS_private}) {
+    llvm::SmallVectorImpl<const CXXMethodDecl *> &GroupedMethods =
+        GroupedAccessMethods[Spec];
+    if (GroupedMethods.empty())
+      continue;
+    printMethods(OS, GroupedMethods, Selected, getAccessSpelling(Spec));
+  }
+  if (!Buffer.empty()) {
+    if (auto Err = Replacements.add(tooling::Replacement(
+            Inputs.AST->getSourceManager(), Selected->getBraceRange().getEnd(),
+            0, Buffer)))
+      return std::move(Err);
+  }
+  return Effect::mainFileEdit(Inputs.AST->getASTContext().getSourceManager(),
+                              std::move(Replacements));
+}
+
+REGISTER_TWEAK(ImplementAbstract) // Presumably also provides id().
+
+} // namespace
+} // namespace clangd
+} // namespace clang
Index: clang-tools-extra/clangd/refactor/tweaks/CMakeLists.txt
===================================================================
--- clang-tools-extra/clangd/refactor/tweaks/CMakeLists.txt
+++ clang-tools-extra/clangd/refactor/tweaks/CMakeLists.txt
@@ -21,6 +21,7 @@
   ExpandMacro.cpp
   ExtractFunction.cpp
   ExtractVariable.cpp
+  ImplementAbstract.cpp
   ObjCLocalizeStringLiteral.cpp
   PopulateSwitch.cpp
   RawStringLiteral.cpp
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to