https://github.com/Venyla updated 
https://github.com/llvm/llvm-project/pull/69693

>From 2b1f56a758d397649c94717f4a030c04a532bde7 Mon Sep 17 00:00:00 2001
From: Vina Zahnd <vina.za...@gmail.com>
Date: Fri, 20 Oct 2023 10:01:54 +0200
Subject: [PATCH] [clangd] Add tweak to inline concept requirements

Co-authored-by: Jeremy Stucki <d...@jeremystucki.ch>
---
 .../clangd/refactor/tweaks/CMakeLists.txt     |   1 +
 .../tweaks/InlineConceptRequirement.cpp       | 262 ++++++++++++++++++
 .../clangd/unittests/CMakeLists.txt           |   1 +
 .../tweaks/InlineConceptRequirement.cpp       |  94 +++++++
 4 files changed, 358 insertions(+)
 create mode 100644 clang-tools-extra/clangd/refactor/tweaks/InlineConceptRequirement.cpp
 create mode 100644 clang-tools-extra/clangd/unittests/tweaks/InlineConceptRequirement.cpp

diff --git a/clang-tools-extra/clangd/refactor/tweaks/CMakeLists.txt b/clang-tools-extra/clangd/refactor/tweaks/CMakeLists.txt
index 526a073f619ea34..b01053faf738a90 100644
--- a/clang-tools-extra/clangd/refactor/tweaks/CMakeLists.txt
+++ b/clang-tools-extra/clangd/refactor/tweaks/CMakeLists.txt
@@ -21,6 +21,7 @@ add_clang_library(clangDaemonTweaks OBJECT
   ExpandMacro.cpp
   ExtractFunction.cpp
   ExtractVariable.cpp
+  InlineConceptRequirement.cpp
   MemberwiseConstructor.cpp
   ObjCLocalizeStringLiteral.cpp
   ObjCMemberwiseInitializer.cpp
diff --git a/clang-tools-extra/clangd/refactor/tweaks/InlineConceptRequirement.cpp b/clang-tools-extra/clangd/refactor/tweaks/InlineConceptRequirement.cpp
new file mode 100644
index 000000000000000..b6c0237703c3474
--- /dev/null
+++ b/clang-tools-extra/clangd/refactor/tweaks/InlineConceptRequirement.cpp
@@ -0,0 +1,262 @@
+//===--- InlineConceptRequirement.cpp ----------------------------*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#include "ParsedAST.h"
+#include "SourceCode.h"
+#include "refactor/Tweak.h"
+#include "support/Logger.h"
+#include "clang/AST/ASTContext.h"
+#include "clang/AST/ExprConcepts.h"
+#include "clang/Tooling/Core/Replacement.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/Error.h"
+
+namespace clang {
+namespace clangd {
+namespace {
+/// Inlines a concept requirement.
+///
+/// Before:
+///   template <typename T> void f(T) requires foo<T> {}
+///                                            ^^^^^^
+/// After:
+///   template <foo T> void f(T) {}
+/// Inlines a concept requirement.
+///
+/// Before:
+///   template <typename T> void f(T) requires foo<T> {}
+///                                            ^^^^^^
+/// After:
+///   template <foo T> void f(T) {}
+///
+/// Only a requires-clause that consists of a single concept applied to one
+/// template type parameter (declared with `typename`) of the enclosing
+/// function template is handled; conjunctions and multi-argument concepts
+/// are rejected by prepare().
+class InlineConceptRequirement : public Tweak {
+public:
+  const char *id() const final;
+
+  auto prepare(const Selection &Inputs) -> bool override;
+  auto apply(const Selection &Inputs) -> Expected<Effect> override;
+  auto title() const -> std::string override {
+    return "Inline concept requirement";
+  }
+  auto kind() const -> llvm::StringLiteral override {
+    return CodeAction::REFACTOR_KIND;
+  }
+
+private:
+  // State computed by prepare() and consumed by apply(). All three point into
+  // the ParsedAST of the selection. They start out null so that an early
+  // return from prepare() never leaves them indeterminate.
+  const ConceptSpecializationExpr *ConceptSpecializationExpression = nullptr;
+  const TemplateTypeParmDecl *TemplateTypeParameterDeclaration = nullptr;
+  const syntax::Token *RequiresToken = nullptr;
+
+  /// Returns the index (within its own template parameter list) of the
+  /// template type parameter named by a type template argument, or
+  /// std::nullopt if the argument does not name one.
+  static auto getTemplateParameterIndexOfTemplateArgument(
+      const TemplateArgument &TemplateArgument) -> std::optional<int>;
+  /// Builds the edit deleting the constraint expression, or an error if its
+  /// range cannot be mapped to the file.
+  auto generateRequiresReplacement(ASTContext &)
+      -> std::variant<tooling::Replacement, llvm::Error>;
+  /// Builds the edit deleting the 'requires' keyword.
+  auto generateRequiresTokenReplacement(const syntax::TokenBuffer &)
+      -> tooling::Replacement;
+  /// Builds the edit rewriting the template parameter to use the concept.
+  auto generateTemplateParameterReplacement(ASTContext &Context)
+      -> tooling::Replacement;
+
+  /// Returns the first expanded token of the given kind inside the given
+  /// range, or null if there is none.
+  static auto findToken(const ParsedAST *, const SourceRange &,
+                        const tok::TokenKind) -> const syntax::Token *;
+
+  /// Walks from Root (inclusive) up through its ancestors and returns the
+  /// first AST node that, viewed as a NodeKind, is a T, together with its
+  /// selection-tree node. Both are null if nothing matches.
+  template <typename T, typename NodeKind>
+  static auto findNode(const SelectionTree::Node &Root)
+      -> std::tuple<const T *, const SelectionTree::Node *>;
+
+  /// findNode() specialised to expressions.
+  template <typename T>
+  static auto findExpression(const SelectionTree::Node &Root)
+      -> std::tuple<const T *, const SelectionTree::Node *> {
+    return findNode<T, Expr>(Root);
+  }
+
+  /// findNode() specialised to declarations.
+  template <typename T>
+  static auto findDeclaration(const SelectionTree::Node &Root)
+      -> std::tuple<const T *, const SelectionTree::Node *> {
+    return findNode<T, Decl>(Root);
+  }
+};
+
+// NOTE(review): id() above has no visible definition; REGISTER_TWEAK
+// (refactor/Tweak.h) presumably provides it — confirm.
+REGISTER_TWEAK(InlineConceptRequirement)
+
+auto InlineConceptRequirement::prepare(const Selection &Inputs) -> bool {
+  // Concepts and requires-clauses only exist since C++20.
+  if (!Inputs.AST->getLangOpts().CPlusPlus20)
+    return false;
+
+  const auto *Root = Inputs.ASTSelection.commonAncestor();
+  if (!Root)
+    return false;
+
+  const SelectionTree::Node *ConceptSpecializationExpressionTreeNode;
+  std::tie(ConceptSpecializationExpression,
+           ConceptSpecializationExpressionTreeNode) =
+      findExpression<ConceptSpecializationExpr>(*Root);
+  if (!ConceptSpecializationExpression ||
+      !ConceptSpecializationExpressionTreeNode->Parent)
+    return false;
+
+  // Only allow concepts that are direct children of function template
+  // declarations or function declarations. This excludes conjunctions of
+  // concepts which are not handled.
+  const auto *ParentDeclaration =
+      ConceptSpecializationExpressionTreeNode->Parent->ASTNode.get<Decl>();
+  if (!isa_and_nonnull<FunctionTemplateDecl>(ParentDeclaration) &&
+      !isa_and_nonnull<FunctionDecl>(ParentDeclaration))
+    return false;
+
+  const FunctionTemplateDecl *FunctionTemplateDeclaration =
+      std::get<0>(findDeclaration<FunctionTemplateDecl>(*Root));
+  if (!FunctionTemplateDeclaration)
+    return false;
+
+  auto TemplateArguments =
+      ConceptSpecializationExpression->getTemplateArguments();
+  if (TemplateArguments.size() != 1)
+    return false;
+
+  auto TemplateParameterIndex =
+      getTemplateParameterIndexOfTemplateArgument(TemplateArguments[0]);
+  if (!TemplateParameterIndex)
+    return false;
+
+  // The index is relative to the parameter list the argument's parameter
+  // belongs to, which need not be this function template's list (e.g.
+  // `foo<T>` may name a parameter of an enclosing class template). Bounds-
+  // check it and compare depths so we neither read past the end of the list
+  // nor rewrite an unrelated parameter that merely shares the index.
+  const TemplateParameterList *TemplateParameters =
+      FunctionTemplateDeclaration->getTemplateParameters();
+  if (static_cast<unsigned>(*TemplateParameterIndex) >=
+      TemplateParameters->size())
+    return false;
+
+  // Null for non-type and template template parameters.
+  TemplateTypeParameterDeclaration = dyn_cast_or_null<TemplateTypeParmDecl>(
+      TemplateParameters->getParam(*TemplateParameterIndex));
+  if (!TemplateTypeParameterDeclaration)
+    return false;
+
+  const auto *ArgumentType =
+      TemplateArguments[0].getAsType()->getAs<TemplateTypeParmType>();
+  if (!ArgumentType ||
+      ArgumentType->getDepth() != TemplateTypeParameterDeclaration->getDepth())
+    return false;
+
+  if (!TemplateTypeParameterDeclaration->wasDeclaredWithTypename())
+    return false;
+
+  // Bail out if the parameter cannot be mapped to a file range (e.g. it comes
+  // from a macro); the rewrite in apply() needs one.
+  auto &Context = Inputs.AST->getASTContext();
+  if (!toHalfOpenFileRange(Context.getSourceManager(), Context.getLangOpts(),
+                           TemplateTypeParameterDeclaration->getSourceRange()))
+    return false;
+
+  // Use the 'requires' keyword closest before the constraint, i.e. the one
+  // introducing this clause. The first 'requires' in the whole declaration is
+  // the wrong one when the template has both a template-head and a trailing
+  // requires-clause.
+  const syntax::TokenBuffer &Tokens = Inputs.AST->getTokens();
+  auto Candidates = Tokens.expandedTokens(
+      SourceRange(FunctionTemplateDeclaration->getBeginLoc(),
+                  ConceptSpecializationExpression->getBeginLoc()));
+  auto RequiresIt = std::find_if(
+      Candidates.rbegin(), Candidates.rend(), [](const syntax::Token &Tok) {
+        return Tok.kind() == tok::kw_requires;
+      });
+  if (RequiresIt == Candidates.rend())
+    return false;
+  RequiresToken = &*RequiresIt;
+
+  // apply() maps the keyword back to the characters the user typed and
+  // dereferences the result; this mapping can fail for tokens produced by
+  // macro expansions.
+  if (!Tokens.spelledForExpanded(llvm::ArrayRef(*RequiresToken)))
+    return false;
+
+  return true;
+}
+
+auto InlineConceptRequirement::apply(const Selection &Inputs)
+    -> Expected<Tweak::Effect> {
+  ASTContext &Ctx = Inputs.AST->getASTContext();
+  const syntax::TokenBuffer &Tokens = Inputs.AST->getTokens();
+  tooling::Replacements Edits;
+
+  // 1. Rewrite the template parameter so that it names the concept.
+  if (llvm::Error E = Edits.add(generateTemplateParameterReplacement(Ctx)))
+    return E;
+
+  // 2. Delete the constraint expression; propagate the error if its range
+  //    could not be computed.
+  auto ConstraintEdit = generateRequiresReplacement(Ctx);
+  if (auto *Failure = std::get_if<llvm::Error>(&ConstraintEdit))
+    return std::move(*Failure);
+  if (llvm::Error E = Edits.add(std::get<tooling::Replacement>(ConstraintEdit)))
+    return E;
+
+  // 3. Delete the 'requires' keyword that introduced the clause.
+  if (llvm::Error E = Edits.add(generateRequiresTokenReplacement(Tokens)))
+    return E;
+
+  return Effect::mainFileEdit(Ctx.getSourceManager(), Edits);
+}
+
+auto InlineConceptRequirement::getTemplateParameterIndexOfTemplateArgument(
+    const TemplateArgument &Argument) -> std::optional<int> {
+  // Only a type argument can refer to a template type parameter; when it
+  // does, report that parameter's index within its own parameter list.
+  if (Argument.getKind() == TemplateArgument::Type) {
+    QualType ArgType = Argument.getAsType();
+    if (ArgType->isTemplateTypeParmType())
+      if (const auto *Parm = ArgType->getAs<TemplateTypeParmType>())
+        return Parm->getIndex();
+  }
+  return std::nullopt;
+}
+
+auto InlineConceptRequirement::generateRequiresReplacement(ASTContext &Context)
+    -> std::variant<tooling::Replacement, llvm::Error> {
+  const SourceManager &SM = Context.getSourceManager();
+
+  // The constraint expression as a half-open character range in the file.
+  const auto ConstraintRange =
+      toHalfOpenFileRange(SM, Context.getLangOpts(),
+                          ConceptSpecializationExpression->getSourceRange());
+  if (ConstraintRange) {
+    const llvm::StringRef ConstraintText = toSourceCode(SM, *ConstraintRange);
+    // Replacing the constraint's characters with nothing deletes it.
+    return tooling::Replacement(SM, ConstraintRange->getBegin(),
+                                ConstraintText.size(), "");
+  }
+  return error("Could not obtain range of the 'requires' branch. Macros?");
+}
+
+auto InlineConceptRequirement::generateRequiresTokenReplacement(
+    const syntax::TokenBuffer &TokenBuffer) -> tooling::Replacement {
+  const SourceManager &SM = TokenBuffer.sourceManager();
+
+  // Map the expanded 'requires' token back to the text the user wrote. The
+  // result is dereferenced without a check.
+  auto SpelledTokens =
+      TokenBuffer.spelledForExpanded(llvm::ArrayRef(*RequiresToken));
+  CharSourceRange KeywordRange =
+      syntax::Token::range(SM, SpelledTokens->front(), SpelledTokens->back())
+          .toCharRange(SM);
+
+  // An empty replacement deletes the keyword.
+  return tooling::Replacement(SM, KeywordRange, "");
+}
+
+auto InlineConceptRequirement::generateTemplateParameterReplacement(
+    ASTContext &Context) -> tooling::Replacement {
+  auto &SourceManager = Context.getSourceManager();
+  const auto &LangOpts = Context.getLangOpts();
+
+  // Reuse the concept name exactly as written in the requires-clause,
+  // including any nested-name-specifier. getQualifiedNameAsString() is not
+  // always valid source: for a concept in an anonymous namespace it yields
+  // "(anonymous namespace)::foo". It is only a fallback for when the written
+  // name cannot be mapped to a file range.
+  std::string ConceptName;
+  if (auto ConceptNameRange = toHalfOpenFileRange(
+          SourceManager, LangOpts,
+          SourceRange(ConceptSpecializationExpression->getBeginLoc(),
+                      ConceptSpecializationExpression->getConceptNameLoc())))
+    ConceptName = toSourceCode(SourceManager, *ConceptNameRange).str();
+  else
+    ConceptName = ConceptSpecializationExpression->getNamedConcept()
+                      ->getQualifiedNameAsString();
+
+  // A template parameter is referred to by its plain name.
+  std::string TemplateParameterReplacement =
+      ConceptName + ' ' + TemplateTypeParameterDeclaration->getName().str();
+
+  // Replace only `typename T`. The declaration's full source range also
+  // covers a default argument (`typename T = int`), which must be kept.
+  auto TemplateParameterRange = toHalfOpenFileRange(
+      SourceManager, LangOpts,
+      SourceRange(TemplateTypeParameterDeclaration->getBeginLoc(),
+                  TemplateTypeParameterDeclaration->getLocation()));
+  assert(TemplateParameterRange &&
+         "template parameter must map to a file range");
+
+  auto SourceCode = toSourceCode(SourceManager, *TemplateParameterRange);
+
+  return tooling::Replacement(SourceManager, TemplateParameterRange->getBegin(),
+                              SourceCode.size(), TemplateParameterReplacement);
+}
+
+auto InlineConceptRequirement::findToken(const ParsedAST *AST,
+                                         const SourceRange &Range,
+                                         const tok::TokenKind TokenKind)
+    -> const syntax::Token * {
+  // Scan the expanded (post-preprocessing) tokens covering Range and return
+  // the first one of the requested kind. The tokens are stored in the AST's
+  // token buffer, so the returned pointer refers into it.
+  for (const syntax::Token &Candidate : AST->getTokens().expandedTokens(Range))
+    if (Candidate.kind() == TokenKind)
+      return &Candidate;
+  return nullptr;
+}
+
+template <typename T, typename NodeKind>
+auto InlineConceptRequirement::findNode(const SelectionTree::Node &Root)
+    -> std::tuple<const T *, const SelectionTree::Node *> {
+  // Climb from Root towards the root of the selection tree and stop at the
+  // first node whose AST node, viewed as a NodeKind, is a T.
+  const SelectionTree::Node *Current = &Root;
+  while (Current) {
+    const T *Match = dyn_cast_or_null<T>(Current->ASTNode.get<NodeKind>());
+    if (Match)
+      return {Match, Current};
+    Current = Current->Parent;
+  }
+  // Nothing matched: both tuple elements are value-initialized (null).
+  return {};
+}
+
+} // namespace
+} // namespace clangd
+} // namespace clang
diff --git a/clang-tools-extra/clangd/unittests/CMakeLists.txt b/clang-tools-extra/clangd/unittests/CMakeLists.txt
index 8d02b91fdd71669..de9376e439a4a3f 100644
--- a/clang-tools-extra/clangd/unittests/CMakeLists.txt
+++ b/clang-tools-extra/clangd/unittests/CMakeLists.txt
@@ -125,6 +125,7 @@ add_unittest(ClangdUnitTests ClangdTests
   tweaks/ExpandMacroTests.cpp
   tweaks/ExtractFunctionTests.cpp
   tweaks/ExtractVariableTests.cpp
+  tweaks/InlineConceptRequirement.cpp
   tweaks/MemberwiseConstructorTests.cpp
   tweaks/ObjCLocalizeStringLiteralTests.cpp
   tweaks/ObjCMemberwiseInitializerTests.cpp
diff --git a/clang-tools-extra/clangd/unittests/tweaks/InlineConceptRequirement.cpp b/clang-tools-extra/clangd/unittests/tweaks/InlineConceptRequirement.cpp
new file mode 100644
index 000000000000000..648a08c434467d3
--- /dev/null
+++ b/clang-tools-extra/clangd/unittests/tweaks/InlineConceptRequirement.cpp
@@ -0,0 +1,94 @@
+//===-- InlineConceptRequirement.cpp ----------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TweakTesting.h"
+#include "gtest/gtest.h"
+
+namespace clang {
+namespace clangd {
+namespace {
+
+TWEAK_TEST(InlineConceptRequirement);
+
+TEST_F(InlineConceptRequirementTest, Test) {
+  // Concepts referenced by the snippets below. Their definitions do not
+  // matter for the rewrite, so they are trivially satisfied.
+  Header = R"cpp(
+      template <typename T>
+      concept foo = true;
+
+      template <typename T>
+      concept bar = true;
+
+      template <typename T, typename U>
+      concept baz = true;
+    )cpp";
+
+  // The tweak only offers itself in C++20 and later.
+  ExtraArgs = {"-std=c++20"};
+
+  //
+  // Extra spaces are expected and will be stripped by the formatter.
+  //
+
+  // Trailing requires-clause constraining a later template parameter.
+  EXPECT_EQ(
+      apply("template <typename T, typename U> void f(T) requires f^oo<U> {}"),
+      "template <typename T, foo U> void f(T)   {}");
+
+  // requires-clause in the template head.
+  EXPECT_EQ(
+      apply("template <typename T, typename U> requires foo<^T> void f(T) {}"),
+      "template <foo T, typename U>   void f(T) {}");
+
+  // A template template parameter in the same list is left untouched.
+  EXPECT_EQ(apply("template <template <typename> class FooBar, typename T> "
+                  "void f() requires foo<^T> {}"),
+            "template <template <typename> class FooBar, foo T> void f()   {}");
+
+  // Offered anywhere on a lone concept constraint, in both trailing and
+  // template-head requires-clauses.
+  EXPECT_AVAILABLE(R"cpp(
+      template <typename T> void f(T)
+        requires ^f^o^o^<^T^> {}
+    )cpp");
+
+  EXPECT_AVAILABLE(R"cpp(
+      template <typename T> requires ^f^o^o^<^T^>
+      void f(T) {}
+    )cpp");
+
+  EXPECT_AVAILABLE(R"cpp(
+      template <typename T, typename U> void f(T)
+        requires ^f^o^o^<^T^> {}
+    )cpp");
+
+  EXPECT_AVAILABLE(R"cpp(
+      template <template <typename> class FooBar, typename T>
+      void foobar() requires ^f^o^o^<^T^>
+      {}
+    )cpp");
+
+  // The parameter is already constrained.
+  EXPECT_UNAVAILABLE(R"cpp(
+      template <bar T> void f(T)
+        requires ^f^o^o^<^T^> {}
+    )cpp");
+
+  // Concepts taking more than one argument are not handled.
+  EXPECT_UNAVAILABLE(R"cpp(
+      template <typename T, typename U> void f(T, U)
+        requires ^b^a^z^<^T^,^ ^U^> {}
+    )cpp");
+
+  // Conjunctions of constraints are not handled.
+  EXPECT_UNAVAILABLE(R"cpp(
+      template <typename T> void f(T)
+        requires ^f^o^o^<^T^>^ ^&^&^ ^b^a^r^<^T^> {}
+    )cpp");
+
+  // Constraints inside a concept definition are not function requires-clauses.
+  EXPECT_UNAVAILABLE(R"cpp(
+      template <typename T>
+      concept ^f^o^o^b^a^r = requires(^T^ ^x^) {
+        {x} -> ^f^o^o^;
+      };
+    )cpp");
+}
+
+} // namespace
+} // namespace clangd
+} // namespace clang

_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to