arphaman updated this revision to Diff 110563.
arphaman edited the summary of this revision.
arphaman added a comment.

- Simplify error/diagnostic handling. Use `DiagnosticOr` instead of 
`DiagOr<Expected>`.
- Simplify the code for the selection requirements by removing lambda deducers
and using dedicated requirement classes instead of lambdas/functions.
- Rename `selectionRequirement` to `requiredSelection`


Repository:
  rL LLVM

https://reviews.llvm.org/D36075

Files:
  include/clang/Basic/AllDiagnostics.h
  include/clang/Basic/CMakeLists.txt
  include/clang/Basic/Diagnostic.td
  include/clang/Basic/DiagnosticIDs.h
  include/clang/Basic/DiagnosticOr.h
  include/clang/Basic/DiagnosticRefactoringKinds.td
  include/clang/Basic/LLVM.h
  include/clang/Tooling/Refactoring/AtomicChange.h
  include/clang/Tooling/Refactoring/RefactoringActionRules.h
  include/clang/Tooling/Refactoring/RefactoringDiagnostic.h
  include/clang/Tooling/Refactoring/RefactoringOperationController.h
  include/clang/Tooling/Refactoring/RefactoringResult.h
  include/clang/Tooling/Refactoring/SourceSelectionConstraints.h
  lib/Basic/DiagnosticIDs.cpp
  tools/diagtool/DiagnosticNames.cpp
  unittests/Tooling/CMakeLists.txt
  unittests/Tooling/RefactoringActionRulesTest.cpp

Index: unittests/Tooling/RefactoringActionRulesTest.cpp
===================================================================
--- /dev/null
+++ unittests/Tooling/RefactoringActionRulesTest.cpp
@@ -0,0 +1,157 @@
+//===- unittest/Tooling/RefactoringActionRulesTest.cpp --------------------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include "ReplacementTest.h"
+#include "RewriterTestContext.h"
+#include "clang/Tooling/Refactoring.h"
+#include "clang/Tooling/Refactoring/RefactoringActionRules.h"
+#include "clang/Tooling/Tooling.h"
+#include "llvm/Support/Errc.h"
+#include "gtest/gtest.h"
+
+using namespace clang;
+using namespace tooling;
+using namespace refactoring_action_rules;
+
+namespace {
+
+class RefactoringActionRulesTest : public ::testing::Test {
+protected:
+  void SetUp() override {
+    FileID MainFile = Context.createInMemoryFile("input.cpp", DefaultCode);
+    Context.Sources.setMainFileID(MainFile);
+  }
+  // Each test works on "input.cpp", whose contents are 100 'a' characters.
+  RewriterTestContext Context;
+  std::string DefaultCode = std::string(100, 'a');
+};
+
+TEST_F(RefactoringActionRulesTest, MyFirstRefactoringRule) {
+  auto ReplaceAWithB =
+      [](std::pair<selection::SourceSelectionRange, int> Selection)
+      -> Expected<RefactoringResult> {
+    const SourceManager &SM = Selection.first.getSources();
+    SourceLocation Loc = Selection.first.getRange().getBegin().getLocWithOffset(
+        Selection.second);
+    AtomicChange Change(SM, Loc);
+    llvm::Error E = Change.replace(SM, Loc, 1, "b");
+    if (E)
+      return std::move(E);
+    return Change;
+  };
+  class SelectionRequirement : public selection::Requirement {
+  public:
+    std::pair<selection::SourceSelectionRange, int>
+    evaluateSelection(selection::SourceSelectionRange Selection) const {
+      return std::make_pair(Selection, 20);
+    }
+  };
+  auto Rule = apply(ReplaceAWithB, requiredSelection(SelectionRequirement()));
+
+  // When the requirements are satisfied, the rule's function must be invoked.
+  {
+    RefactoringOperationController Operation(Context.Sources);
+    SourceLocation Cursor =
+        Context.Sources.getLocForStartOfFile(Context.Sources.getMainFileID())
+            .getLocWithOffset(10);
+    Operation.setSelectionRange({Cursor, Cursor});
+
+    DiagnosticOr<RefactoringResult> DiagOrResult = Rule->perform(Operation);
+    ASSERT_FALSE(!DiagOrResult);
+    RefactoringResult Result = std::move(*DiagOrResult);
+    ASSERT_EQ(Result.getKind(), RefactoringResult::AtomicChanges);
+    ASSERT_EQ(Result.getChanges().size(), 1u);
+    std::string YAMLString = Result.getChanges()[0].toYAMLString();
+
+    ASSERT_STREQ("---\n"
+                 "Key:             'input.cpp:30'\n"
+                 "FilePath:        input.cpp\n"
+                 "Error:           ''\n"
+                 "InsertedHeaders: \n"
+                 "RemovedHeaders:  \n"
+                 "Replacements:    \n" // Extra whitespace here!
+                 "  - FilePath:        input.cpp\n"
+                 "    Offset:          30\n"
+                 "    Length:          1\n"
+                 "    ReplacementText: b\n"
+                 "...\n",
+                 YAMLString.c_str());
+  }
+
+  // When one of the requirements is not satisfied, perform should return a
+  // diagnostic instead of a refactoring result.
+  {
+    RefactoringOperationController Operation(Context.Sources);
+    DiagnosticOr<RefactoringResult> DiagOrResult = Rule->perform(Operation);
+
+    // A failure to select returns the invalidSelectionError.
+    ASSERT_TRUE(!DiagOrResult);
+    EXPECT_EQ(DiagOrResult.getDiagnostic().first, SourceLocation());
+    EXPECT_EQ(DiagOrResult.getDiagnostic().second.getDiagID(), 0u);
+  }
+}
+
+TEST_F(RefactoringActionRulesTest, ReturnError) {
+  // A rule function that always fails with an llvm::Error.
+  Expected<RefactoringResult> (*AlwaysFails)(selection::SourceSelectionRange) =
+      [](selection::SourceSelectionRange) -> Expected<RefactoringResult> {
+    return llvm::make_error<llvm::StringError>(
+        "Error", std::make_error_code(std::errc::bad_message));
+  };
+  auto Rule = apply(AlwaysFails, requiredSelection(selection::identity<
+                                     selection::SourceSelectionRange>()));
+  // Select an empty range at the start of the main file.
+  RefactoringOperationController Operation(Context.Sources);
+  SourceLocation FileStart =
+      Context.Sources.getLocForStartOfFile(Context.Sources.getMainFileID());
+  Operation.setSelectionRange({FileStart, FileStart});
+  // The function's error text must surface as the rule's diagnostic.
+  DiagnosticOr<RefactoringResult> Outcome = Rule->perform(Operation);
+  ASSERT_TRUE(!Outcome);
+  SmallString<100> Message;
+  Outcome.getDiagnostic().second.EmitToString(Context.Diagnostics, Message);
+  EXPECT_EQ(Message.str(), "Error");
+}
+
+TEST_F(RefactoringActionRulesTest, ReturnInitiationDiagnostic) {
+  unsigned CustomID = Context.Diagnostics.getCustomDiagID(
+      DiagnosticsEngine::Error, "Diagnostic: %0");
+  RefactoringOperationController Operation(Context.Sources);
+  PartialDiagnostic Diag(CustomID, Operation.getDiagnosticStorage());
+  // A requirement that always rejects the selection with the diagnostic above.
+  class FailingRequirement : public selection::Requirement {
+  public:
+    const PartialDiagnostic &StoredDiag;
+    FailingRequirement(const PartialDiagnostic &D) : StoredDiag(D) {}
+
+    DiagnosticOr<int>
+    evaluateSelection(selection::SourceSelectionRange Selection) const {
+      SourceLocation Begin = Selection.getRange().getBegin();
+      return PartialDiagnosticAt(Begin, StoredDiag << "test");
+    }
+  };
+  // Initiation fails, so the rule's function must never be invoked.
+  auto Rule = apply(
+      [](int) -> Expected<RefactoringResult> {
+        llvm::report_fatal_error("Should not run!");
+      },
+      requiredSelection(FailingRequirement(Diag)));
+  SourceLocation FileStart =
+      Context.Sources.getLocForStartOfFile(Context.Sources.getMainFileID());
+  Operation.setSelectionRange({FileStart, FileStart});
+  DiagnosticOr<RefactoringResult> Outcome = Rule->perform(Operation);
+
+  ASSERT_TRUE(!Outcome);
+  EXPECT_EQ(Outcome.getDiagnostic().first, FileStart);
+  SmallString<100> Message;
+  Outcome.getDiagnostic().second.EmitToString(Context.Diagnostics, Message);
+  EXPECT_EQ(Message.str(), "Diagnostic: test");
+}
+
+} // end anonymous namespace
Index: unittests/Tooling/CMakeLists.txt
===================================================================
--- unittests/Tooling/CMakeLists.txt
+++ unittests/Tooling/CMakeLists.txt
@@ -23,6 +23,7 @@
   RecursiveASTVisitorTestDeclVisitor.cpp
   RecursiveASTVisitorTestExprVisitor.cpp
   RecursiveASTVisitorTestTypeLocVisitor.cpp
+  RefactoringActionRulesTest.cpp
   RefactoringCallbacksTest.cpp
   RefactoringTest.cpp
   ReplacementsYamlTest.cpp
Index: tools/diagtool/DiagnosticNames.cpp
===================================================================
--- tools/diagtool/DiagnosticNames.cpp
+++ tools/diagtool/DiagnosticNames.cpp
@@ -41,6 +41,7 @@
 #include "clang/Basic/DiagnosticCommentKinds.inc"
 #include "clang/Basic/DiagnosticSemaKinds.inc"
 #include "clang/Basic/DiagnosticAnalysisKinds.inc"
+#include "clang/Basic/DiagnosticRefactoringKinds.inc"
 #undef DIAG
 };
 
Index: lib/Basic/DiagnosticIDs.cpp
===================================================================
--- lib/Basic/DiagnosticIDs.cpp
+++ lib/Basic/DiagnosticIDs.cpp
@@ -43,7 +43,7 @@
   unsigned SFINAE : 2;
   unsigned WarnNoWerror : 1;
   unsigned WarnShowInSystemHeader : 1;
-  unsigned Category : 5;
+  unsigned Category : 6;
 
   uint16_t OptionGroupIndex;
 
@@ -88,6 +88,7 @@
 #include "clang/Basic/DiagnosticCommentKinds.inc"
 #include "clang/Basic/DiagnosticSemaKinds.inc"
 #include "clang/Basic/DiagnosticAnalysisKinds.inc"
+#include "clang/Basic/DiagnosticRefactoringKinds.inc"
 #undef DIAG
 };
 
@@ -137,6 +138,7 @@
 CATEGORY(COMMENT, AST)
 CATEGORY(SEMA, COMMENT)
 CATEGORY(ANALYSIS, SEMA)
+CATEGORY(REFACTORING, ANALYSIS)
 #undef CATEGORY
 
   // Avoid out of bounds reads.
Index: include/clang/Tooling/Refactoring/SourceSelectionConstraints.h
===================================================================
--- /dev/null
+++ include/clang/Tooling/Refactoring/SourceSelectionConstraints.h
@@ -0,0 +1,102 @@
+//===--- SourceSelectionConstraints.h - Clang refactoring library ---------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_TOOLING_REFACTOR_SOURCE_SELECTION_CONSTRAINTS_H
+#define LLVM_CLANG_TOOLING_REFACTOR_SOURCE_SELECTION_CONSTRAINTS_H
+
+#include "clang/Basic/SourceLocation.h"
+#include <type_traits>
+
+namespace clang {
+namespace tooling {
+namespace selection {
+
+/// This constraint is satisfied when any portion of the source text is
+/// selected. It can be used to obtain the raw source selection range.
+struct SourceSelectionRange {
+  SourceSelectionRange(const SourceManager &SM, SourceRange Range)
+      : SM(SM), Range(Range) {}
+
+  const SourceManager &getSources() const { return SM; }
+  SourceRange getRange() const { return Range; }
+
+private:
+  const SourceManager &SM; ///< Not owned; must outlive this object.
+  SourceRange Range;
+};
+
+/// Base class for custom selection requirements.
+class Requirement {
+  /// A subclass provides a 'T evaluateSelection(Constraint) const' member
+  /// function, where \c Constraint is a selection constraint such as
+  /// \c SourceSelectionRange. The type \c T determines what the refactoring
+  /// rule's function receives:
+  ///  - \c DiagnosticOr<S>: on success \c S is moved into the function; a
+  ///    diagnostic instead aborts the initiation of the refactoring action.
+  ///  - any other \c T: the value itself is moved into the function.
+  ///
+  /// This lets actions refuse to start when the needed AST isn't selected.
+};
+
+namespace traits {
+
+/// Trait that is true iff \c T is a selection constraint. It is false by
+/// default; each constraint type opts in with an explicit specialization.
+template <typename T> struct IsConstraint : std::false_type {};
+
+} // end namespace traits
+
+namespace detail {
+
+/// True only for 'R (T::*)(A) const'; then exposes ReturnType and ArgType.
+template <typename F> struct EvaluateSelectionChecker : std::false_type {};
+template <typename T, typename R, typename A>
+struct EvaluateSelectionChecker<R (T::*)(A) const> : std::true_type {
+  using ArgType = A;
+  using ReturnType = R;
+};
+/// Requirement that returns the evaluated selection constraint unchanged.
+template <typename T> class Identity : public Requirement {
+public:
+  T evaluateSelection(T Constraint) const { return std::move(Constraint); }
+};
+
+} // end namespace detail
+
+/// An identity requirement that returns the given selection constraint as
+/// is. Provided for convenience: it can be passed to \c requiredSelection.
+template <typename T> detail::Identity<T> identity() {
+  static_assert(
+      traits::IsConstraint<T>::value,
+      "selection::identity can be used with selection constraints only");
+  return detail::Identity<T>();
+}
+
+namespace traits {
+
+/// \c SourceSelectionRange is a selection constraint.
+template <> struct IsConstraint<SourceSelectionRange> : std::true_type {};
+
+/// True iff \c T derives from \c Requirement and declares an
+/// 'evaluateSelection' const member function taking one selection constraint.
+template <typename T>
+struct IsRequirement
+    : std::integral_constant<
+          bool, std::is_base_of<Requirement, T>::value &&
+                    detail::EvaluateSelectionChecker<decltype(
+                        &T::evaluateSelection)>::value &&
+                    IsConstraint<typename detail::EvaluateSelectionChecker<
+                        decltype(&T::evaluateSelection)>::ArgType>::value> {};
+
+} // end namespace traits
+} // end namespace selection
+} // end namespace tooling
+} // end namespace clang
+
+#endif // LLVM_CLANG_TOOLING_REFACTOR_SOURCE_SELECTION_CONSTRAINTS_H
Index: include/clang/Tooling/Refactoring/RefactoringResult.h
===================================================================
--- /dev/null
+++ include/clang/Tooling/Refactoring/RefactoringResult.h
@@ -0,0 +1,49 @@
+//===--- RefactoringResult.h - Clang refactoring library ------------------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_TOOLING_REFACTOR_REFACTORING_RESULT_H
+#define LLVM_CLANG_TOOLING_REFACTOR_REFACTORING_RESULT_H
+
+#include "clang/Tooling/Refactoring/AtomicChange.h"
+
+namespace clang {
+namespace tooling {
+
+/// Refactoring result is a move-only variant that stores the output of a
+/// refactoring operation. Currently it can only hold a set of atomic changes.
+struct RefactoringResult {
+  enum ResultKind {
+    /// A set of source replacements represented using a vector of
+    /// \c AtomicChanges.
+    AtomicChanges
+  };
+  /// Implicit so that a rule's function can simply return an \c AtomicChange.
+  RefactoringResult(AtomicChange Change) : Kind(AtomicChanges) {
+    Changes.push_back(std::move(Change));
+  }
+  RefactoringResult(RefactoringResult &&Other) = default;
+  RefactoringResult &operator=(RefactoringResult &&Other) = default;
+
+  ResultKind getKind() const { return Kind; }
+  /// Returns the stored changes. Only valid when getKind() == AtomicChanges.
+  llvm::MutableArrayRef<AtomicChange> getChanges() {
+    assert(getKind() == AtomicChanges &&
+           "Refactoring didn't produce atomic changes");
+    return Changes;
+  }
+
+private:
+  ResultKind Kind;
+  std::vector<AtomicChange> Changes;
+};
+
+} // end namespace tooling
+} // end namespace clang
+
+#endif // LLVM_CLANG_TOOLING_REFACTOR_REFACTORING_RESULT_H
Index: include/clang/Tooling/Refactoring/RefactoringOperationController.h
===================================================================
--- /dev/null
+++ include/clang/Tooling/Refactoring/RefactoringOperationController.h
@@ -0,0 +1,48 @@
+//===--- RefactoringOperationController.h - Clang refactoring library -----===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_TOOLING_REFACTOR_REFACTORING_OPERATION_CONTROLLER_H
+#define LLVM_CLANG_TOOLING_REFACTOR_REFACTORING_OPERATION_CONTROLLER_H
+
+#include "clang/Basic/PartialDiagnostic.h"
+#include "clang/Basic/SourceManager.h"
+
+namespace clang {
+namespace tooling {
+
+/// Holds the per-operation state of a refactoring: the source manager, the
+/// selection chosen by the refactoring engine, and the allocator used for
+/// partial diagnostics. Rule requirements are evaluated against this state
+/// when an operation is initiated.
+class RefactoringOperationController {
+public:
+  RefactoringOperationController(const SourceManager &Sources) : SM(Sources) {}
+
+  /// Storage allocator for \c PartialDiagnostic objects of this operation.
+  PartialDiagnostic::StorageAllocator &getDiagnosticStorage() {
+    return DiagnosticStorage;
+  }
+  const SourceManager &getSources() const { return SM; }
+
+  /// The source selection range set by the refactoring engine. May be an
+  /// invalid range, e.g. when no selection has been set.
+  SourceRange getSelectionRange() const { return SelectionRange; }
+  void setSelectionRange(SourceRange Range) { SelectionRange = Range; }
+
+private:
+  /// Not owned; must outlive the controller.
+  const SourceManager &SM;
+  SourceRange SelectionRange;
+  PartialDiagnostic::StorageAllocator DiagnosticStorage;
+};
+
+} // end namespace tooling
+} // end namespace clang
+
+#endif // LLVM_CLANG_TOOLING_REFACTOR_REFACTORING_OPERATION_CONTROLLER_H
Index: include/clang/Tooling/Refactoring/RefactoringDiagnostic.h
===================================================================
--- /dev/null
+++ include/clang/Tooling/Refactoring/RefactoringDiagnostic.h
@@ -0,0 +1,45 @@
+//===--- RefactoringDiagnostic.h - Refactoring diagnostics ------*- C++ -*-===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_TOOLING_REFACTORING_REFACTORINGDIAGNOSTIC_H
+#define LLVM_CLANG_TOOLING_REFACTORING_REFACTORINGDIAGNOSTIC_H
+
+#include "clang/Basic/Diagnostic.h"
+#include "clang/Basic/PartialDiagnostic.h"
+
+namespace clang {
+namespace diag {
+enum { // One enumerator per DIAG entry in DiagnosticRefactoringKinds.inc.
+#define DIAG(ENUM, FLAGS, DEFAULT_MAPPING, DESC, GROUP, SFINAE, NOWERROR,      \
+             SHOWINSYSHEADER, CATEGORY)                                        \
+  ENUM,
+#define REFACTORINGSTART
+#include "clang/Basic/DiagnosticRefactoringKinds.inc"
+#undef DIAG
+  NUM_BUILTIN_REFACTORING_DIAGNOSTICS // Follows the last refactoring diag ID.
+};
+} // end namespace diag
+
+namespace tooling {
+
+/// Builds the diagnostic reported when a refactoring action cannot be
+/// initiated because the source selection does not contain the AST nodes
+/// the action requires. It carries an invalid location and a null
+/// diagnostic (ID 0).
+inline PartialDiagnosticAt invalidSelectionError() {
+  // Intentionally empty: each engine interprets it in its own way.
+  // clang-refactor converts it to a proper diagnostic; IDEs may ignore it.
+  return PartialDiagnosticAt(SourceLocation(),
+                             PartialDiagnostic::NullDiagnostic());
+}
+
+} // end namespace tooling
+} // end namespace clang
+
+#endif // LLVM_CLANG_TOOLING_REFACTORING_REFACTORINGDIAGNOSTIC_H
Index: include/clang/Tooling/Refactoring/RefactoringActionRules.h
===================================================================
--- /dev/null
+++ include/clang/Tooling/Refactoring/RefactoringActionRules.h
@@ -0,0 +1,274 @@
+//===--- RefactoringActionRules.h - Clang refactoring library -------------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_TOOLING_REFACTOR_REFACTORING_ACTION_RULES_H
+#define LLVM_CLANG_TOOLING_REFACTOR_REFACTORING_ACTION_RULES_H
+
+#include "clang/Basic/DiagnosticOr.h"
+#include "clang/Basic/LLVM.h"
+#include "clang/Tooling/Refactoring/RefactoringDiagnostic.h"
+#include "clang/Tooling/Refactoring/RefactoringOperationController.h"
+#include "clang/Tooling/Refactoring/RefactoringResult.h"
+#include "clang/Tooling/Refactoring/SourceSelectionConstraints.h"
+#include "llvm/Support/Error.h"
+#include <type_traits>
+
+namespace clang {
+namespace tooling {
+
+/// Interface implemented by every refactoring action rule. The engine only
+/// interacts with rules through \c perform.
+class RefactoringActionRule {
+public:
+  virtual ~RefactoringActionRule() = default;
+
+  /// Invoked by the refactoring engine to attempt the refactoring operation.
+  ///
+  /// Returns a diagnostic if the action could not be initiated or if the
+  /// rule's function failed; otherwise returns the refactoring result the
+  /// function produced.
+  virtual DiagnosticOr<RefactoringResult>
+  perform(RefactoringOperationController &Controller) = 0;
+};
+
+/// A list of refactoring action rules. The rules in one list are expected to
+/// have unique initiation requirements.
+using RefactoringActionRules =
+    std::vector<std::unique_ptr<RefactoringActionRule>>;
+
+namespace refactoring_action_rules {
+
+namespace detail {
+/// Tag base class for rule requirements: \c traits::IsRequirement accepts
+/// exactly the types that derive from it.
+struct RequirementBase {};
+} // end namespace detail
+
+namespace traits {
+
+/// True iff every type in the pack is a valid rule requirement, i.e. it
+/// derives from \c detail::RequirementBase.
+template <typename First, typename... Rest>
+struct IsRequirement
+    : std::integral_constant<bool, IsRequirement<First>::value &&
+                                       IsRequirement<Rest...>::value> {};
+
+/// Single-type case, which also terminates the recursion above.
+template <typename T>
+struct IsRequirement<T>
+    : std::integral_constant<
+          bool, std::is_base_of<detail::RequirementBase, T>::value> {};
+
+} // end namespace traits
+
+namespace detail {
+
+/// Defines a type alias \c Type that is \c T when given \c DiagnosticOr<T>,
+/// and \c T itself otherwise.
+template <typename T> struct DropDiagnosticOr { using Type = T; };
+
+template <typename T> struct DropDiagnosticOr<DiagnosticOr<T>> {
+  using Type = T;
+};
+
+/// Requirement produced by \c requiredSelection. \c RequirementT evaluates an
+/// \c InputT selection constraint; the rule function receives \c OutputType.
+template <typename InputT, typename OutputT, typename RequirementT>
+struct SourceSelectionRequirement
+    : std::enable_if<selection::traits::IsConstraint<InputT>::value &&
+                         selection::traits::IsRequirement<RequirementT>::value,
+                     RequirementBase>::type {
+  using OutputType = typename DropDiagnosticOr<OutputT>::Type;
+  /// Keeps a copy of \p Requirement for evaluation at initiation time.
+  SourceSelectionRequirement(const RequirementT &Requirement)
+      : Requirement(Requirement) {}
+
+private:
+  const RequirementT Requirement;
+  friend class BaseSpecializedRule; // Evaluates Requirement.
+};
+
+/// A wrapper class around \c RefactoringActionRule that defines helper
+/// methods used by the specialized rule subclasses.
+class BaseSpecializedRule : public RefactoringActionRule {
+protected:
+  /// Evaluates a source selection action rule requirement.
+  template <typename InputT, typename OutputT, typename RequirementT>
+  static DiagnosticOr<typename DropDiagnosticOr<OutputT>::Type>
+  evaluate(RefactoringOperationController &Controller,
+           const SourceSelectionRequirement<InputT, OutputT, RequirementT>
+               &SelectionRequirement) {
+    Optional<InputT> Value = evalSelection<InputT>(Controller);
+    if (!Value)
+      return invalidSelectionError();
+    // No std::move: wrapping a temporary in it only prevents copy elision.
+    return SelectionRequirement.Requirement.evaluateSelection(*Value);
+  }
+
+  /// Returns the value held by a valid \c DiagnosticOr<T>, or \c T itself.
+  template <typename T> static T &&removeDiagOr(DiagnosticOr<T> &&X) {
+    assert(X && "unexpected diagnostic!");
+    return std::move(*X);
+  }
+  template <typename T> static T &&removeDiagOr(T &&X) { return std::move(X); }
+
+  using OptionalDiag = Optional<PartialDiagnosticAt>;
+
+  /// Scans the tuple and returns a \c PartialDiagnosticAt
+  /// from the first invalid \c DiagnosticOr value. Returns \c None if all
+  /// values are valid.
+  template <typename FirstT, typename... RestT>
+  static OptionalDiag findDiag(FirstT &First, RestT &... Rest) {
+    OptionalDiag Result = takeDiagOrNone(First);
+    if (Result)
+      return Result;
+    return findDiag(Rest...);
+  }
+
+private:
+  /// Evaluates a selection constraint.
+  template <typename T>
+  static typename std::enable_if<selection::traits::IsConstraint<T>::value,
+                                 llvm::Optional<T>>::type
+  evalSelection(RefactoringOperationController &Controller);
+
+  static OptionalDiag findDiag() { return OptionalDiag(); }
+  template <typename T> static OptionalDiag takeDiagOrNone(T &) {
+    return OptionalDiag();
+  }
+  template <typename T>
+  static OptionalDiag takeDiagOrNone(DiagnosticOr<T> &Diag) {
+    if (!Diag)
+      return std::move(Diag.getDiagnostic());
+    return OptionalDiag();
+  }
+};
+
+/// Evaluates \c selection::SourceSelectionRange: any valid selection matches.
+template <>
+inline llvm::Optional<selection::SourceSelectionRange>
+BaseSpecializedRule::evalSelection<selection::SourceSelectionRange>(
+    RefactoringOperationController &Controller) {
+  SourceRange Selection = Controller.getSelectionRange();
+  if (Selection.isValid())
+    return selection::SourceSelectionRange(Controller.getSources(), Selection);
+  return None;
+}
+
+/// A specialized refactoring action rule that calls the stored function once
+/// all of the requirements are fulfilled. The values produced during the
+/// evaluation of requirements are passed to the stored function.
+template <typename FunctionType, typename... RequirementTypes>
+class PlainFunctionRule final : public BaseSpecializedRule {
+public:
+  PlainFunctionRule(FunctionType Function,
+                    std::tuple<RequirementTypes...> &&Requirements)
+      : Function(Function), Requirements(std::move(Requirements)) {}
+
+  DiagnosticOr<RefactoringResult>
+  perform(RefactoringOperationController &Controller) override {
+    return performImpl(Controller,
+                       llvm::index_sequence_for<RequirementTypes...>());
+  }
+
+private:
+  template <size_t... Is>
+  DiagnosticOr<RefactoringResult>
+  performImpl(RefactoringOperationController &Controller,
+              llvm::index_sequence<Is...>) {
+    // Initiate the operation.
+    auto Values =
+        std::make_tuple(evaluate(Controller, std::get<Is>(Requirements))...);
+    OptionalDiag InitiationFailure = findDiag(std::get<Is>(Values)...);
+    if (InitiationFailure)
+      return std::move(*InitiationFailure);
+    // Perform the operation.
+    Expected<RefactoringResult> Result =
+        Function(removeDiagOr(std::move(std::get<Is>(Values)))...);
+    if (Result)
+      return std::move(*Result);
+    std::string Error = llvm::toString(Result.takeError());
+    return PartialDiagnosticAt(
+        SourceLocation(),
+        PartialDiagnostic(diag::err_refactor_rule_function_failed,
+                          Controller.getDiagnosticStorage())
+            << Error);
+  }
+
+  FunctionType Function;
+  std::tuple<RequirementTypes...> Requirements;
+};
+
+} // end namespace detail
+
+/// Creates a refactoring action rule that invokes \p RefactoringFunction once
+/// every requirement is satisfied. The values produced while evaluating the
+/// requirements are passed to the function, in the order in which the
+/// requirements are listed.
+///
+/// \param RefactoringFunction the function that performs the refactoring
+/// after the requirements are satisfied.
+///
+/// \param Requirements the rule requirements that have to be satisfied. Each
+/// must be a valid requirement, i.e. \c traits::IsRequirement<T>::value must
+/// be true. Currently supported requirements:
+///
+///  - requiredSelection: the refactoring function is not invoked unless the
+///                       given selection requirement is satisfied.
+template <typename... RequirementTypes>
+std::unique_ptr<RefactoringActionRule>
+apply(Expected<RefactoringResult> (*RefactoringFunction)(
+          typename RequirementTypes::OutputType...),
+      const RequirementTypes &... Requirements) {
+  static_assert(traits::IsRequirement<RequirementTypes...>::value,
+                "invalid refactoring action rule requirement");
+  using RuleType = detail::PlainFunctionRule<decltype(RefactoringFunction),
+                                             RequirementTypes...>;
+  return llvm::make_unique<RuleType>(RefactoringFunction,
+                                     std::make_tuple(Requirements...));
+}
+
+/// Creates a selection requirement from the given requirement object.
+///
+/// \p Requirement must be an instance of a class that derives from
+/// \c selection::Requirement and declares an 'evaluateSelection' const member
+/// function taking a single selection constraint (see
+/// \c selection::traits::IsRequirement). The constraint type is deduced from
+/// the parameter of that function, and the value handed to the rule's
+/// function from its return type (with any \c DiagnosticOr wrapper dropped).
+/// Returns a \c detail::SourceSelectionRequirement holding a copy of it.
+template <typename T>
+detail::SourceSelectionRequirement<
+    typename selection::detail::EvaluateSelectionChecker<
+        decltype(&T::evaluateSelection)>::ArgType,
+    typename selection::detail::EvaluateSelectionChecker<
+        decltype(&T::evaluateSelection)>::ReturnType,
+    T>
+requiredSelection(
+    const T &Requirement,
+    typename std::enable_if<selection::traits::IsRequirement<T>::value>::type
+        * = nullptr) {
+  return {Requirement};
+}
+
+template <typename T>
+void requiredSelection(
+    const T &,
+    typename std::enable_if<
+        !std::is_base_of<selection::Requirement, T>::value>::type * = nullptr) {
+  static_assert(
+      sizeof(T) && false,
+      "selection requirement must be a class derived from Requirement");
+}
+
+} // end namespace refactoring_action_rules
+} // end namespace tooling
+} // end namespace clang
+
+#endif // LLVM_CLANG_TOOLING_REFACTOR_REFACTORING_ACTION_RULES_H
Index: include/clang/Tooling/Refactoring/AtomicChange.h
===================================================================
--- include/clang/Tooling/Refactoring/AtomicChange.h
+++ include/clang/Tooling/Refactoring/AtomicChange.h
@@ -46,6 +46,12 @@
   AtomicChange(llvm::StringRef FilePath, llvm::StringRef Key)
       : Key(Key), FilePath(FilePath) {}
 
+  AtomicChange(AtomicChange &&) = default;
+  AtomicChange(const AtomicChange &) = default;
+
+  AtomicChange &operator=(AtomicChange &&) = default;
+  AtomicChange &operator=(const AtomicChange &) = default;
+
   /// \brief Returns the atomic change as a YAML string.
   std::string toYAMLString();
 
Index: include/clang/Basic/LLVM.h
===================================================================
--- include/clang/Basic/LLVM.h
+++ include/clang/Basic/LLVM.h
@@ -35,6 +35,7 @@
   template<typename T, unsigned N> class SmallVector;
   template<typename T> class SmallVectorImpl;
   template<typename T> class Optional;
+  template <class T> class Expected;
 
   template<typename T>
   struct SaveAndRestore;
@@ -71,6 +72,9 @@
   using llvm::SmallVectorImpl;
   using llvm::SaveAndRestore;
 
+  // Error handling.
+  using llvm::Expected;
+
   // Reference counting.
   using llvm::IntrusiveRefCntPtr;
   using llvm::IntrusiveRefCntPtrInfo;
Index: include/clang/Basic/DiagnosticRefactoringKinds.td
===================================================================
--- /dev/null
+++ include/clang/Basic/DiagnosticRefactoringKinds.td
@@ -0,0 +1,22 @@
+//==--- DiagnosticRefactoringKinds.td - refactoring diagnostics -----------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// Refactoring Diagnostics
+//===----------------------------------------------------------------------===//
+
+let Component = "Refactoring" in {
+
+let CategoryName = "Refactoring Invocation Issue" in {
+// A rule's function returned an llvm::Error; %0 is the error's message.
+def err_refactor_rule_function_failed : Error<"%0">;
+
+}
+
+} // end of Refactoring diagnostics
Index: include/clang/Basic/DiagnosticOr.h
===================================================================
--- /dev/null
+++ include/clang/Basic/DiagnosticOr.h
@@ -0,0 +1,150 @@
+//===--- DiagnosticOr.h - A value or a diagnostic ---------------*- C++ -*-===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// \brief Defines DiagnosticOr<T>, a tagged union holding either a value of
+/// type T or a PartialDiagnosticAt explaining why no value was produced.
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_BASIC_DIAGNOSTIC_OR_H
+#define LLVM_CLANG_BASIC_DIAGNOSTIC_OR_H
+
+#include "clang/Basic/PartialDiagnostic.h"
+#include "llvm/Support/AlignOf.h"
+#include "llvm/Support/Compiler.h"
+#include <cassert>
+#include <new>
+#include <type_traits>
+#include <utility>
+
+namespace clang {
+
+/// Tagged union holding either a T or a PartialDiagnosticAt.
+///
+/// This class parallels llvm::Expected, but replaces Error with
+/// PartialDiagnosticAt. T must not be a reference type.
+template <class T> class LLVM_NODISCARD DiagnosticOr {
+  static_assert(!std::is_reference<T>::value,
+                "DiagnosticOr<T> does not support reference types");
+
+private:
+  using storage_type = T;
+  using reference = T &;
+  using const_reference = const T &;
+  using pointer = T *;
+  using const_pointer = const T *;
+
+public:
+  /// Create a DiagnosticOr<T> that holds the given partial diagnostic.
+  DiagnosticOr(PartialDiagnosticAt Diagnostic) : HasDiagnostic(true) {
+    new (getDiagnosticStorage()) PartialDiagnosticAt(std::move(Diagnostic));
+  }
+
+  /// Create a DiagnosticOr<T> success value from the given OtherT value, which
+  /// must be convertible to T.
+  template <typename OtherT>
+  DiagnosticOr(
+      OtherT &&Val,
+      typename std::enable_if<std::is_convertible<OtherT, T>::value>::type * =
+          nullptr)
+      : HasDiagnostic(false) {
+    new (getStorage()) storage_type(std::forward<OtherT>(Val));
+  }
+
+  /// Move-construct from \p Other, taking over its value or diagnostic.
+  DiagnosticOr(DiagnosticOr<T> &&Other) { moveConstruct(std::move(Other)); }
+
+  /// Move-assign from \p Other. Self-move-assignment is a no-op.
+  DiagnosticOr<T> &operator=(DiagnosticOr<T> &&Other) {
+    moveAssign(std::move(Other));
+    return *this;
+  }
+
+  /// Destroy whichever of the value or the diagnostic is currently held.
+  ~DiagnosticOr() {
+    if (!HasDiagnostic)
+      getStorage()->~storage_type();
+    else
+      getDiagnosticStorage()->~PartialDiagnosticAt();
+  }
+
+  /// Returns true if a value is held, false if there is a diagnostic.
+  explicit operator bool() const { return !HasDiagnostic; }
+
+  /// Access the held diagnostic. Asserts that a diagnostic is present.
+  PartialDiagnosticAt &getDiagnostic() { return *getDiagnosticStorage(); }
+
+  const PartialDiagnosticAt &getDiagnostic() const {
+    return *getDiagnosticStorage();
+  }
+
+  /// Access the held value. These assert that no diagnostic is present.
+  pointer operator->() { return getStorage(); }
+
+  const_pointer operator->() const { return getStorage(); }
+
+  reference operator*() { return *getStorage(); }
+
+  const_reference operator*() const { return *getStorage(); }
+
+private:
+  /// Placement-construct the member held by \p Other into this object.
+  void moveConstruct(DiagnosticOr<T> &&Other) {
+    HasDiagnostic = Other.HasDiagnostic;
+
+    if (!HasDiagnostic)
+      new (getStorage()) storage_type(std::move(*Other.getStorage()));
+    else
+      new (getDiagnosticStorage())
+          PartialDiagnosticAt(std::move(*Other.getDiagnosticStorage()));
+  }
+
+  void moveAssign(DiagnosticOr<T> &&Other) {
+    // On self-move, destroying *this would also destroy Other's contents
+    // before they are moved, so leave the object untouched.
+    if (this == &Other)
+      return;
+    this->~DiagnosticOr();
+    new (this) DiagnosticOr<T>(std::move(Other));
+  }
+
+  storage_type *getStorage() {
+    assert(!HasDiagnostic && "Cannot get value when a diagnostic exists!");
+    return reinterpret_cast<storage_type *>(TStorage.buffer);
+  }
+
+  const storage_type *getStorage() const {
+    assert(!HasDiagnostic && "Cannot get value when a diagnostic exists!");
+    return reinterpret_cast<const storage_type *>(TStorage.buffer);
+  }
+
+  PartialDiagnosticAt *getDiagnosticStorage() {
+    assert(HasDiagnostic && "Cannot get diagnostic when a value exists!");
+    return reinterpret_cast<PartialDiagnosticAt *>(DiagnosticStorage.buffer);
+  }
+
+  const PartialDiagnosticAt *getDiagnosticStorage() const {
+    assert(HasDiagnostic && "Cannot get diagnostic when a value exists!");
+    return reinterpret_cast<const PartialDiagnosticAt *>(
+        DiagnosticStorage.buffer);
+  }
+
+  /// Storage for exactly one of the value or the diagnostic, selected by
+  /// HasDiagnostic.
+  union {
+    llvm::AlignedCharArrayUnion<storage_type> TStorage;
+    llvm::AlignedCharArrayUnion<PartialDiagnosticAt> DiagnosticStorage;
+  };
+  bool HasDiagnostic : 1;
+};
+
+} // end namespace clang
+
+#endif // LLVM_CLANG_BASIC_DIAGNOSTIC_OR_H
Index: include/clang/Basic/DiagnosticIDs.h
===================================================================
--- include/clang/Basic/DiagnosticIDs.h
+++ include/clang/Basic/DiagnosticIDs.h
@@ -38,7 +38,8 @@
       DIAG_START_COMMENT       = DIAG_START_AST             +  110,
       DIAG_START_SEMA          = DIAG_START_COMMENT         +  100,
       DIAG_START_ANALYSIS      = DIAG_START_SEMA            + 3500,
-      DIAG_UPPER_LIMIT         = DIAG_START_ANALYSIS        +  100
+      DIAG_START_REFACTORING   = DIAG_START_ANALYSIS        +  100,
+      DIAG_UPPER_LIMIT         = DIAG_START_REFACTORING     +  500
     };
 
     class CustomDiagInfo;
Index: include/clang/Basic/Diagnostic.td
===================================================================
--- include/clang/Basic/Diagnostic.td
+++ include/clang/Basic/Diagnostic.td
@@ -137,6 +137,7 @@
 include "DiagnosticFrontendKinds.td"
 include "DiagnosticLexKinds.td"
 include "DiagnosticParseKinds.td"
+include "DiagnosticRefactoringKinds.td"
 include "DiagnosticSemaKinds.td"
 include "DiagnosticSerializationKinds.td"
 
Index: include/clang/Basic/CMakeLists.txt
===================================================================
--- include/clang/Basic/CMakeLists.txt
+++ include/clang/Basic/CMakeLists.txt
@@ -13,6 +13,7 @@
 clang_diag_gen(Frontend)
 clang_diag_gen(Lex)
 clang_diag_gen(Parse)
+clang_diag_gen(Refactoring)
 clang_diag_gen(Sema)
 clang_diag_gen(Serialization)
 clang_tablegen(DiagnosticGroups.inc -gen-clang-diag-groups
Index: include/clang/Basic/AllDiagnostics.h
===================================================================
--- include/clang/Basic/AllDiagnostics.h
+++ include/clang/Basic/AllDiagnostics.h
@@ -24,6 +24,7 @@
 #include "clang/Parse/ParseDiagnostic.h"
 #include "clang/Sema/SemaDiagnostic.h"
 #include "clang/Serialization/SerializationDiagnostic.h"
+#include "clang/Tooling/Refactoring/RefactoringDiagnostic.h"
 
 namespace clang {
 template <size_t SizeOfStr, typename FieldType>
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
http://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to