zhaomo updated this revision to Diff 348113.

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D103195/new/

https://reviews.llvm.org/D103195

Files:
  clang/include/clang/ASTMatchers/GtestMatchers.h
  clang/lib/ASTMatchers/GtestMatchers.cpp
  clang/unittests/ASTMatchers/GtestMatchersTest.cpp

Index: clang/unittests/ASTMatchers/GtestMatchersTest.cpp
===================================================================
--- clang/unittests/ASTMatchers/GtestMatchersTest.cpp
+++ clang/unittests/ASTMatchers/GtestMatchersTest.cpp
@@ -42,6 +42,14 @@
 #define EXPECT_PRED_FORMAT2(pred_format, v1, v2) \
     GTEST_PRED_FORMAT2_(pred_format, v1, v2, GTEST_NONFATAL_FAILURE_)
 
+#define GTEST_PRED_FORMAT1_(pred_format, v1, on_failure) \
+  GTEST_ASSERT_(pred_format(#v1, v1), on_failure)
+
+#define EXPECT_PRED_FORMAT1(pred_format, v1) \
+  GTEST_PRED_FORMAT1_(pred_format, v1, GTEST_NONFATAL_FAILURE_)
+#define ASSERT_PRED_FORMAT1(pred_format, v1) \
+  GTEST_PRED_FORMAT1_(pred_format, v1, GTEST_FATAL_FAILURE_)
+
 #define EXPECT_EQ(val1, val2) \
     EXPECT_PRED_FORMAT2(::testing::internal::EqHelper::Compare, val1, val2)
 #define EXPECT_NE(val1, val2) \
@@ -55,11 +63,29 @@
 #define EXPECT_LT(val1, val2) \
     EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperLT, val1, val2)
 
+#define ASSERT_THAT(value, matcher) \
+  ASSERT_PRED_FORMAT1(              \
+      ::testing::internal::MakePredicateFormatterFromMatcher(matcher), value)
+#define EXPECT_THAT(value, matcher) \
+  EXPECT_PRED_FORMAT1(              \
+      ::testing::internal::MakePredicateFormatterFromMatcher(matcher), value)
+
 #define ASSERT_EQ(val1, val2) \
     ASSERT_PRED_FORMAT2(::testing::internal::EqHelper::Compare, val1, val2)
 #define ASSERT_NE(val1, val2) \
     ASSERT_PRED_FORMAT2(::testing::internal::CmpHelperNE, val1, val2)
 
+#define GMOCK_ON_CALL_IMPL_(mock_expr, Setter, call)                    \
+  ((mock_expr).gmock_##call)(::testing::internal::GetWithoutMatchers(), \
+                             nullptr)                                   \
+      .Setter(nullptr, 0, #mock_expr, #call)
+
+#define ON_CALL(obj, call) \
+  GMOCK_ON_CALL_IMPL_(obj, InternalDefaultActionSetAt, call)
+
+#define EXPECT_CALL(obj, call) \
+  GMOCK_ON_CALL_IMPL_(obj, InternalExpectedAt, call)
+
   namespace testing {
   namespace internal {
   class EqHelper {
@@ -96,8 +122,77 @@
                   const T2& val2) {
     return 0;
   }
+
+  // For implementing ASSERT_THAT() and EXPECT_THAT().  The template
+  // argument M must be a type that can be converted to a matcher.
+  template <typename M>
+  class PredicateFormatterFromMatcher {
+   public:
+    explicit PredicateFormatterFromMatcher(M m) : matcher_(m) {}
+
+    // This template () operator allows a PredicateFormatterFromMatcher
+    // object to act as a predicate-formatter suitable for using with
+    // Google Test's EXPECT_PRED_FORMAT1() macro.
+    template <typename T>
+    int operator()(const char* value_text, const T& x) const {
+      return 0;
+    }
+
+   private:
+    const M matcher_;
+  };
+
+  template <typename M>
+  inline PredicateFormatterFromMatcher<M> MakePredicateFormatterFromMatcher(
+      M matcher) {
+    return PredicateFormatterFromMatcher<M>(matcher);
+  }
+
+  bool GetWithoutMatchers() { return false; }
+
+  template <typename F>
+  class MockSpec {
+   public:
+    MockSpec() {}
+
+    bool InternalDefaultActionSetAt(
+        const char* file, int line, const char* obj, const char* call) {
+      return false;
+    }
+
+    bool InternalExpectedAt(
+        const char* file, int line, const char* obj, const char* call) {
+      return false;
+    }
+
+    MockSpec<F> operator()(bool, void*) {
+      return *this;
+    }
+  };  // class MockSpec
+
   }  // namespace internal
+
+  template <typename T>
+  int StrEq(T val) {
+    return 0;
+  }
+  template <typename T>
+  int Eq(T val) {
+    return 0;
+  }
+
   }  // namespace testing
+
+  class Mock {
+    public:
+    Mock() {}
+    testing::internal::MockSpec<int> gmock_TwoArgsMethod(int, int) {
+      return testing::internal::MockSpec<int>();
+    }
+    testing::internal::MockSpec<int> gmock_TwoArgsMethod(bool, void*) {
+      return testing::internal::MockSpec<int>();
+    }
+  };  // class Mock
 )cc";
 
 static std::string wrapGtest(llvm::StringRef Input) {
@@ -187,5 +282,184 @@
       matches(wrapGtest(Input), gtestExpect(GtestCmp::Gt, expr(), expr())));
 }
 
+TEST(GtestAssertTest, ThatShouldMatchAssertThat) {
+  std::string Input = R"cc(
+    using ::testing::Eq;
+    void Test() { ASSERT_THAT(2, Eq(2)); }
+  )cc";
+  EXPECT_TRUE(matches(
+      wrapGtest(Input),
+      gtestAssertThat(
+          expr(), callExpr(callee(functionDecl(hasName("::testing::Eq")))))));
+}
+
+TEST(GtestExpectTest, ThatShouldMatchExpectThat) {
+  std::string Input = R"cc(
+    using ::testing::Eq;
+    void Test() { EXPECT_THAT(2, Eq(2)); }
+  )cc";
+  EXPECT_TRUE(matches(
+      wrapGtest(Input),
+      gtestExpectThat(
+          expr(), callExpr(callee(functionDecl(hasName("::testing::Eq")))))));
+}
+
+TEST(GtestOnCallTest, CallShouldMatchOnCallWithoutParams1) {
+  std::string Input = R"cc(
+    void Test() {
+      Mock mock;
+      ON_CALL(mock, TwoArgsMethod);
+    }
+  )cc";
+  EXPECT_TRUE(matches(wrapGtest(Input),
+                      gtestOnCall(expr(hasType(cxxRecordDecl(hasName("Mock")))),
+                                  "TwoArgsMethod", MockArgs::NoMatchers)));
+}
+
+TEST(GtestOnCallTest, CallShouldMatchOnCallWithoutParams2) {
+  std::string Input = R"cc(
+    void Test() {
+      Mock mock;
+      ON_CALL(mock, TwoArgsMethod);
+    }
+  )cc";
+  EXPECT_TRUE(matches(
+      wrapGtest(Input),
+      gtestOnCall(cxxMemberCallExpr(
+                      callee(functionDecl(hasName("gmock_TwoArgsMethod"))))
+                      .bind("mock_call"),
+                  MockArgs::NoMatchers)));
+}
+
+TEST(GtestOnCallTest, CallShouldMatchOnCallWithParams1) {
+  std::string Input = R"cc(
+    void Test() {
+      Mock mock;
+      ON_CALL(mock, TwoArgsMethod(1, 2));
+    }
+  )cc";
+  EXPECT_TRUE(matches(wrapGtest(Input),
+                      gtestOnCall(expr(hasType(cxxRecordDecl(hasName("Mock")))),
+                                  "TwoArgsMethod", MockArgs::HasMatchers)));
+}
+
+TEST(GtestOnCallTest, CallShouldMatchOnCallWithParams2) {
+  std::string Input = R"cc(
+    void Test() {
+      Mock mock;
+      ON_CALL(mock, TwoArgsMethod(1, 2));
+    }
+  )cc";
+  EXPECT_TRUE(matches(
+      wrapGtest(Input),
+      gtestOnCall(cxxMemberCallExpr(
+                      callee(functionDecl(hasName("gmock_TwoArgsMethod"))))
+                      .bind("mock_call"),
+                  MockArgs::HasMatchers)));
+}
+
+TEST(GtestExpectCallTest, CallShouldMatchExpectCallWithoutParams1) {
+  std::string Input = R"cc(
+    void Test() {
+      Mock mock;
+      EXPECT_CALL(mock, TwoArgsMethod);
+    }
+  )cc";
+  EXPECT_TRUE(
+      matches(wrapGtest(Input),
+              gtestExpectCall(expr(hasType(cxxRecordDecl(hasName("Mock")))),
+                              "TwoArgsMethod", MockArgs::NoMatchers)));
+}
+
+TEST(GtestExpectCallTest, CallShouldMatchExpectCallWithoutParams2) {
+  std::string Input = R"cc(
+    void Test() {
+      Mock mock;
+      EXPECT_CALL(mock, TwoArgsMethod);
+    }
+  )cc";
+  EXPECT_TRUE(matches(
+      wrapGtest(Input),
+      gtestExpectCall(cxxMemberCallExpr(
+                          callee(functionDecl(hasName("gmock_TwoArgsMethod"))))
+                          .bind("mock_call"),
+                      MockArgs::NoMatchers)));
+}
+
+TEST(GtestExpectCallTest, CallShouldMatchExpectCallWithParams1) {
+  std::string Input = R"cc(
+    void Test() {
+      Mock mock;
+      EXPECT_CALL(mock, TwoArgsMethod(1, 2));
+    }
+  )cc";
+  EXPECT_TRUE(
+      matches(wrapGtest(Input),
+              gtestExpectCall(expr(hasType(cxxRecordDecl(hasName("Mock")))),
+                              "TwoArgsMethod", MockArgs::HasMatchers)));
+}
+
+TEST(GtestExpectCallTest, CallShouldMatchExpectCallWithParams2) {
+  std::string Input = R"cc(
+    void Test() {
+      Mock mock;
+      EXPECT_CALL(mock, TwoArgsMethod(1, 2));
+    }
+  )cc";
+  EXPECT_TRUE(matches(
+      wrapGtest(Input),
+      gtestExpectCall(cxxMemberCallExpr(
+                          callee(functionDecl(hasName("gmock_TwoArgsMethod"))))
+                          .bind("mock_call"),
+                      MockArgs::HasMatchers)));
+}
+
+TEST(GtestAssertTest, ThatShouldNotMatchExpectThat) {
+  std::string Input = R"cc(
+    using ::testing::Eq;
+    void Test() { EXPECT_THAT(2, Eq(2)); }
+  )cc";
+  EXPECT_TRUE(notMatches(wrapGtest(Input), gtestAssertThat(expr(), expr())));
+}
+
+TEST(GtestOnCallTest, CallShouldNotMatchExpectCall) {
+  std::string Input = R"cc(
+    void Test() {
+      Mock mock;
+      EXPECT_CALL(mock, TwoArgsMethod);
+    }
+  )cc";
+  EXPECT_TRUE(
+      notMatches(wrapGtest(Input),
+                 gtestOnCall(expr(hasType(cxxRecordDecl(hasName("Mock")))),
+                             "TwoArgsMethod", MockArgs::NoMatchers)));
+}
+
+TEST(GtestOnCallTest, CallWithoutParamsShouldNotMatchOnCallWithParams) {
+  std::string Input = R"cc(
+    void Test() {
+      Mock mock;
+      ON_CALL(mock, TwoArgsMethod(1, 2));
+    }
+  )cc";
+  EXPECT_TRUE(
+      notMatches(wrapGtest(Input),
+                 gtestOnCall(expr(hasType(cxxRecordDecl(hasName("Mock")))),
+                             "TwoArgsMethod", MockArgs::NoMatchers)));
+}
+
+TEST(GtestOnCallTest, CallWithParamsShouldNotMatchOnCallWithoutParams) {
+  std::string Input = R"cc(
+    void Test() {
+      Mock mock;
+      ON_CALL(mock, TwoArgsMethod);
+    }
+  )cc";
+  EXPECT_TRUE(
+      notMatches(wrapGtest(Input),
+                 gtestOnCall(expr(hasType(cxxRecordDecl(hasName("Mock")))),
+                             "TwoArgsMethod", MockArgs::HasMatchers)));
+}
+
 } // end namespace ast_matchers
 } // end namespace clang
Index: clang/lib/ASTMatchers/GtestMatchers.cpp
===================================================================
--- clang/lib/ASTMatchers/GtestMatchers.cpp
+++ clang/lib/ASTMatchers/GtestMatchers.cpp
@@ -7,74 +7,101 @@
 //===----------------------------------------------------------------------===//
 
 #include "clang/ASTMatchers/GtestMatchers.h"
-#include "clang/ASTMatchers/ASTMatchFinder.h"
 #include "clang/AST/ASTConsumer.h"
 #include "clang/AST/ASTContext.h"
 #include "clang/AST/RecursiveASTVisitor.h"
+#include "clang/ASTMatchers/ASTMatchFinder.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/StringMap.h"
-#include "llvm/Support/Timer.h"
-#include <deque>
-#include <memory>
-#include <set>
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/ErrorHandling.h"
 
 namespace clang {
 namespace ast_matchers {
+namespace {
+
+/// Macro families recognized by the matchers below. A family name joined with
+/// a suffix such as "EQ", "THAT" or "CALL" gives the full macro name.
+enum class MacroType {
+  Expect,
+  Assert,
+  On,
+};
 
 static DeclarationMatcher getComparisonDecl(GtestCmp Cmp) {
   switch (Cmp) {
-    case GtestCmp::Eq:
-      return cxxMethodDecl(hasName("Compare"),
-                           ofClass(cxxRecordDecl(isSameOrDerivedFrom(
-                               hasName("::testing::internal::EqHelper")))));
-    case GtestCmp::Ne:
-      return functionDecl(hasName("::testing::internal::CmpHelperNE"));
-    case GtestCmp::Ge:
-      return functionDecl(hasName("::testing::internal::CmpHelperGE"));
-    case GtestCmp::Gt:
-      return functionDecl(hasName("::testing::internal::CmpHelperGT"));
-    case GtestCmp::Le:
-      return functionDecl(hasName("::testing::internal::CmpHelperLE"));
-    case GtestCmp::Lt:
-      return functionDecl(hasName("::testing::internal::CmpHelperLT"));
+  case GtestCmp::Eq:
+    return cxxMethodDecl(hasName("Compare"),
+                         ofClass(cxxRecordDecl(isSameOrDerivedFrom(
+                             hasName("::testing::internal::EqHelper")))));
+  case GtestCmp::Ne:
+    return functionDecl(hasName("::testing::internal::CmpHelperNE"));
+  case GtestCmp::Ge:
+    return functionDecl(hasName("::testing::internal::CmpHelperGE"));
+  case GtestCmp::Gt:
+    return functionDecl(hasName("::testing::internal::CmpHelperGT"));
+  case GtestCmp::Le:
+    return functionDecl(hasName("::testing::internal::CmpHelperLE"));
+  case GtestCmp::Lt:
+    return functionDecl(hasName("::testing::internal::CmpHelperLT"));
   }
   llvm_unreachable("Unhandled GtestCmp enum");
 }
 
-static llvm::StringRef getAssertMacro(GtestCmp Cmp) {
-  switch (Cmp) {
-    case GtestCmp::Eq:
-      return "ASSERT_EQ";
-    case GtestCmp::Ne:
-      return "ASSERT_NE";
-    case GtestCmp::Ge:
-      return "ASSERT_GE";
-    case GtestCmp::Gt:
-      return "ASSERT_GT";
-    case GtestCmp::Le:
-      return "ASSERT_LE";
-    case GtestCmp::Lt:
-      return "ASSERT_LT";
+/// Returns the macro-name prefix for \p Macro, e.g. "EXPECT" for Expect.
+static llvm::StringRef getMacroTypeName(MacroType Macro) {
+  switch (Macro) {
+  case MacroType::Expect:
+    return "EXPECT";
+  case MacroType::Assert:
+    return "ASSERT";
+  case MacroType::On:
+    return "ON";
   }
-  llvm_unreachable("Unhandled GtestCmp enum");
+  llvm_unreachable("Unhandled MacroType enum");
 }
 
-static llvm::StringRef getExpectMacro(GtestCmp Cmp) {
+/// Returns the macro-name suffix for \p Cmp, e.g. "EQ" for GtestCmp::Eq.
+static llvm::StringRef getComparisonTypeName(GtestCmp Cmp) {
   switch (Cmp) {
-    case GtestCmp::Eq:
-      return "EXPECT_EQ";
-    case GtestCmp::Ne:
-      return "EXPECT_NE";
-    case GtestCmp::Ge:
-      return "EXPECT_GE";
-    case GtestCmp::Gt:
-      return "EXPECT_GT";
-    case GtestCmp::Le:
-      return "EXPECT_LE";
-    case GtestCmp::Lt:
-      return "EXPECT_LT";
+  case GtestCmp::Eq:
+    return "EQ";
+  case GtestCmp::Ne:
+    return "NE";
+  case GtestCmp::Ge:
+    return "GE";
+  case GtestCmp::Gt:
+    return "GT";
+  case GtestCmp::Le:
+    return "LE";
+  case GtestCmp::Lt:
+    return "LT";
+  }
+  llvm_unreachable("Unhandled GtestCmp enum");
+}
+
+/// Builds the name of a comparison macro, e.g. "ASSERT_EQ".
+static std::string getMacroName(MacroType Macro, GtestCmp Cmp) {
+  return (getMacroTypeName(Macro) + "_" + getComparisonTypeName(Cmp)).str();
+}
+
+/// Builds the name of a non-comparison macro, e.g. "EXPECT_THAT" or "ON_CALL".
+static std::string getMacroName(MacroType Macro, llvm::StringRef Operation) {
+  return (getMacroTypeName(Macro) + "_" + Operation).str();
+}
+
+/// Returns the name of the setter that the ON_CALL / EXPECT_CALL expansion
+/// calls on the mock spec. Only `On` and `Expect` correspond to such a macro.
+static llvm::StringRef getActionSpecGeneratorName(MacroType Macro) {
+  switch (Macro) {
+  case MacroType::On:
+    return "InternalDefaultActionSetAt";
+  case MacroType::Expect:
+    return "InternalExpectedAt";
+  case MacroType::Assert:
+    llvm_unreachable("There is no ASSERT_CALL macro");
   }
-  llvm_unreachable("Unhandled GtestCmp enum");
+  llvm_unreachable("Unhandled MacroType enum");
 }
 
 // In general, AST matchers cannot match calls to macros. However, we can
@@ -86,18 +100,112 @@
 //
-// We use this approach to implement the derived matchers gtestAssert and
-// gtestExpect.
+// We use this approach to implement the derived matchers gtestAssert,
+// gtestExpect, gtestAssertThat, gtestExpectThat, gtestOnCall and
+// gtestExpectCall.
+static internal::BindableMatcher<Stmt>
+gtestComparisonInternal(MacroType Macro, GtestCmp Cmp, StatementMatcher Left,
+                        StatementMatcher Right) {
+  return callExpr(isExpandedFromMacro(getMacroName(Macro, Cmp)),
+                  callee(getComparisonDecl(Cmp)), hasArgument(2, Left),
+                  hasArgument(3, Right));
+}
+
+/// Matches the predicate-formatter call that ASSERT_THAT / EXPECT_THAT expand
+/// to. `Actual` matches the value being checked and `Matcher` matches the
+/// argument passed to MakePredicateFormatterFromMatcher.
+static internal::BindableMatcher<Stmt>
+gtestThatInternal(MacroType Macro, StatementMatcher Actual,
+                  StatementMatcher Matcher) {
+  return cxxOperatorCallExpr(
+      isExpandedFromMacro(getMacroName(Macro, "THAT")),
+      hasOverloadedOperatorName("()"), hasArgument(2, Actual),
+      hasArgument(
+          0, expr(hasType(classTemplateSpecializationDecl(hasName(
+                      "::testing::internal::PredicateFormatterFromMatcher"))),
+                  ignoringImplicit(
+                      callExpr(callee(functionDecl(hasName(
+                                   "::testing::internal::"
+                                   "MakePredicateFormatterFromMatcher"))),
+                               hasArgument(0, ignoringImplicit(Matcher)))))));
+}
+
+/// Matches the spec-setter call that ON_CALL / EXPECT_CALL expand to.
+/// `MockCall` matches the call to the generated `gmock_`-prefixed method.
+static internal::BindableMatcher<Stmt>
+gtestCallInternal(MacroType Macro, StatementMatcher MockCall, MockArgs Args) {
+  switch (Args) {
+  case MockArgs::NoMatchers:
+    return cxxMemberCallExpr(
+        isExpandedFromMacro(getMacroName(Macro, "CALL")),
+        callee(functionDecl(hasName(getActionSpecGeneratorName(Macro)))),
+        onImplicitObjectArgument(ignoringImplicit(MockCall)));
+  case MockArgs::HasMatchers:
+    // With argument matchers the expansion first invokes `operator()` on the
+    // spec returned by the mock-method call, so look through that call.
+    return cxxMemberCallExpr(
+        isExpandedFromMacro(getMacroName(Macro, "CALL")),
+        callee(functionDecl(hasName(getActionSpecGeneratorName(Macro)))),
+        onImplicitObjectArgument(ignoringImplicit(cxxOperatorCallExpr(
+            hasOverloadedOperatorName("()"), argumentCountIs(3),
+            hasArgument(0, ignoringImplicit(MockCall))))));
+  }
+  llvm_unreachable("Unhandled MockArgs enum");
+}
+
+/// Like the overload above, but builds the `MockCall` matcher from a matcher
+/// for the mock object and the name of the mocked method.
+static internal::BindableMatcher<Stmt>
+gtestCallInternal(MacroType Macro, StatementMatcher MockObject,
+                  llvm::StringRef MockMethodName, MockArgs Args) {
+  return gtestCallInternal(
+      Macro,
+      cxxMemberCallExpr(
+          onImplicitObjectArgument(MockObject),
+          callee(functionDecl(hasName(("gmock_" + MockMethodName).str())))),
+      Args);
+}
+
+} // namespace
+
 internal::BindableMatcher<Stmt> gtestAssert(GtestCmp Cmp, StatementMatcher Left,
                                              StatementMatcher Right) {
-  return callExpr(callee(getComparisonDecl(Cmp)),
-                  isExpandedFromMacro(getAssertMacro(Cmp).str()),
-                  hasArgument(2, Left), hasArgument(3, Right));
+  return gtestComparisonInternal(MacroType::Assert, Cmp, Left, Right);
 }
 
 internal::BindableMatcher<Stmt> gtestExpect(GtestCmp Cmp, StatementMatcher Left,
                                              StatementMatcher Right) {
-  return callExpr(callee(getComparisonDecl(Cmp)),
-                  isExpandedFromMacro(getExpectMacro(Cmp).str()),
-                  hasArgument(2, Left), hasArgument(3, Right));
+  return gtestComparisonInternal(MacroType::Expect, Cmp, Left, Right);
+}
+
+internal::BindableMatcher<Stmt> gtestAssertThat(StatementMatcher Actual,
+                                                StatementMatcher Matcher) {
+  return gtestThatInternal(MacroType::Assert, Actual, Matcher);
+}
+
+internal::BindableMatcher<Stmt> gtestExpectThat(StatementMatcher Actual,
+                                                StatementMatcher Matcher) {
+  return gtestThatInternal(MacroType::Expect, Actual, Matcher);
+}
+
+internal::BindableMatcher<Stmt> gtestOnCall(StatementMatcher MockObject,
+                                            llvm::StringRef MockMethodName,
+                                            MockArgs Args) {
+  return gtestCallInternal(MacroType::On, MockObject, MockMethodName, Args);
+}
+
+internal::BindableMatcher<Stmt> gtestOnCall(StatementMatcher MockCall,
+                                            MockArgs Args) {
+  return gtestCallInternal(MacroType::On, MockCall, Args);
+}
+
+internal::BindableMatcher<Stmt> gtestExpectCall(StatementMatcher MockObject,
+                                                llvm::StringRef MockMethodName,
+                                                MockArgs Args) {
+  return gtestCallInternal(MacroType::Expect, MockObject, MockMethodName, Args);
+}
+
+internal::BindableMatcher<Stmt> gtestExpectCall(StatementMatcher MockCall,
+                                                MockArgs Args) {
+  return gtestCallInternal(MacroType::Expect, MockCall, Args);
 }
 
 } // end namespace ast_matchers
Index: clang/include/clang/ASTMatchers/GtestMatchers.h
===================================================================
--- clang/include/clang/ASTMatchers/GtestMatchers.h
+++ clang/include/clang/ASTMatchers/GtestMatchers.h
@@ -16,6 +16,7 @@
 
 #include "clang/AST/Stmt.h"
 #include "clang/ASTMatchers/ASTMatchers.h"
+#include "llvm/ADT/StringRef.h"
 
 namespace clang {
 namespace ast_matchers {
@@ -30,6 +31,14 @@
   Lt,
 };
 
+/// Whether the mocked method in a matched `ON_CALL` / `EXPECT_CALL` is written
+/// with argument matchers: `NoMatchers` for `ON_CALL(mock, TwoParamMethod)`,
+/// `HasMatchers` for `ON_CALL(mock, TwoParamMethod(m1, m2))`.
+enum class MockArgs {
+  NoMatchers,
+  HasMatchers,
+};
+
 /// Matcher for gtest's ASSERT_... macros.
 internal::BindableMatcher<Stmt> gtestAssert(GtestCmp Cmp, StatementMatcher Left,
                                              StatementMatcher Right);
@@ -38,6 +44,44 @@
 internal::BindableMatcher<Stmt> gtestExpect(GtestCmp Cmp, StatementMatcher Left,
                                             StatementMatcher Right);
 
+/// Matcher for gtest's ASSERT_THAT(Actual, Matcher) macro.
+internal::BindableMatcher<Stmt> gtestAssertThat(StatementMatcher Actual,
+                                                StatementMatcher Matcher);
+
+/// Matcher for gtest's EXPECT_THAT(Actual, Matcher) macro.
+internal::BindableMatcher<Stmt> gtestExpectThat(StatementMatcher Actual,
+                                                StatementMatcher Matcher);
+
+/// Matcher for gMock's `ON_CALL` macro. When `Args` is `NoMatchers`,
+/// this matches a mock call to a method without argument matchers e.g.
+/// `ON_CALL(mock, TwoParamMethod)`; when `Args` is `HasMatchers`, this
+/// matches a mock call to a method with argument matchers e.g.
+/// `ON_CALL(mock, TwoParamMethod(m1, m2))`. `MockObject` matches the mock
+/// object and `MockMethodName` is the name of the method invoked on the mock
+/// object.
+internal::BindableMatcher<Stmt> gtestOnCall(StatementMatcher MockObject,
+                                            llvm::StringRef MockMethodName,
+                                            MockArgs Args);
+
+/// Matcher for gMock's `ON_CALL` macro. When `Args` is `NoMatchers`,
+/// this matches a mock call to a method without argument matchers e.g.
+/// `ON_CALL(mock, TwoParamMethod)`; when `Args` is `HasMatchers`, this
+/// matches a mock call to a method with argument matchers e.g.
+/// `ON_CALL(mock, TwoParamMethod(m1, m2))`. `MockCall` matches the call to the
+/// generated `gmock_`-prefixed member method. This API is more flexible but
+/// requires more knowledge of the AST structure of ON_CALL macros.
+internal::BindableMatcher<Stmt> gtestOnCall(StatementMatcher MockCall,
+                                            MockArgs Args);
+
+/// Like the first `gtestOnCall` overload but for `EXPECT_CALL`.
+internal::BindableMatcher<Stmt> gtestExpectCall(StatementMatcher MockObject,
+                                                llvm::StringRef MockMethodName,
+                                                MockArgs Args);
+
+/// Like the second `gtestOnCall` overload but for `EXPECT_CALL`.
+internal::BindableMatcher<Stmt> gtestExpectCall(StatementMatcher MockCall,
+                                                MockArgs Args);
+
 } // namespace ast_matchers
 } // namespace clang
 
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to