This revision was automatically updated to reflect the committed changes.
Closed by commit rGc0bcd11068fc: [ASTImporter] Add basic support for comparing 
Stmts and compare function bodies (authored by teemperor).
Herald added a project: clang.
Herald added a subscriber: cfe-commits.

  rG LLVM Github Monorepo



Index: clang/unittests/AST/StructuralEquivalenceTest.cpp
--- clang/unittests/AST/StructuralEquivalenceTest.cpp
+++ clang/unittests/AST/StructuralEquivalenceTest.cpp
@@ -19,14 +19,10 @@
   std::unique_ptr<ASTUnit> AST0, AST1;
   std::string Code0, Code1; // Buffers for SourceManager
-  // Get a pair of node pointers into the synthesized AST from the given code
-  // snippets. To determine the returned node, a separate matcher is specified
-  // for both snippets. The first matching node is returned.
-  template <typename NodeType, typename MatcherType>
-  std::tuple<NodeType *, NodeType *>
-  makeDecls(const std::string &SrcCode0, const std::string &SrcCode1,
-            TestLanguage Lang, const MatcherType &Matcher0,
-            const MatcherType &Matcher1) {
+  // Parses the source code in the specified language and sets the ASTs of
+  // the current test instance to the parse result.
+  void makeASTUnits(const std::string &SrcCode0, const std::string &SrcCode1,
+                    TestLanguage Lang) {
     this->Code0 = SrcCode0;
     this->Code1 = SrcCode1;
     std::vector<std::string> Args = getCommandLineArgsForTesting(Lang);
@@ -35,6 +31,17 @@
     AST0 = tooling::buildASTFromCodeWithArgs(Code0, Args, InputFileName);
     AST1 = tooling::buildASTFromCodeWithArgs(Code1, Args, InputFileName);
+  }
+  // Get a pair of node pointers into the synthesized AST from the given code
+  // snippets. To determine the returned node, a separate matcher is specified
+  // for both snippets. The first matching node is returned.
+  template <typename NodeType, typename MatcherType>
+  std::tuple<NodeType *, NodeType *>
+  makeDecls(const std::string &SrcCode0, const std::string &SrcCode1,
+            TestLanguage Lang, const MatcherType &Matcher0,
+            const MatcherType &Matcher1) {
+    makeASTUnits(SrcCode0, SrcCode1, Lang);
     NodeType *D0 = FirstDeclMatcher<NodeType>().match(
         AST0->getASTContext().getTranslationUnitDecl(), Matcher0);
@@ -47,14 +54,7 @@
   std::tuple<TranslationUnitDecl *, TranslationUnitDecl *>
   makeTuDecls(const std::string &SrcCode0, const std::string &SrcCode1,
               TestLanguage Lang) {
-    this->Code0 = SrcCode0;
-    this->Code1 = SrcCode1;
-    std::vector<std::string> Args = getCommandLineArgsForTesting(Lang);
-    const char *const InputFileName = "";
-    AST0 = tooling::buildASTFromCodeWithArgs(Code0, Args, InputFileName);
-    AST1 = tooling::buildASTFromCodeWithArgs(Code1, Args, InputFileName);
+    makeASTUnits(SrcCode0, SrcCode1, Lang);
     return std::make_tuple(AST0->getASTContext().getTranslationUnitDecl(),
@@ -80,6 +80,56 @@
     return makeDecls<NamedDecl>(SrcCode0, SrcCode1, Lang, Matcher);
+  // Wraps a Stmt and the ASTContext that contains it.
+  struct StmtWithASTContext {
+    Stmt *S;
+    ASTContext *Context;
+    explicit StmtWithASTContext(Stmt &S, ASTContext &Context)
+        : S(&S), Context(&Context) {}
+    explicit StmtWithASTContext(FunctionDecl *FD)
+        : S(FD->getBody()), Context(&FD->getASTContext()) {}
+  };
+  // Get a pair of node pointers into the synthesized AST from the given code
+  // snippets. To determine the returned node, a separate matcher is specified
+  // for both snippets. The first matching node is returned.
+  template <typename MatcherType>
+  std::tuple<StmtWithASTContext, StmtWithASTContext>
+  makeStmts(const std::string &SrcCode0, const std::string &SrcCode1,
+            TestLanguage Lang, const MatcherType &Matcher0,
+            const MatcherType &Matcher1) {
+    makeASTUnits(SrcCode0, SrcCode1, Lang);
+    Stmt *S0 = FirstDeclMatcher<Stmt>().match(
+        AST0->getASTContext().getTranslationUnitDecl(), Matcher0);
+    Stmt *S1 = FirstDeclMatcher<Stmt>().match(
+        AST1->getASTContext().getTranslationUnitDecl(), Matcher1);
+    return std::make_tuple(StmtWithASTContext(*S0, AST0->getASTContext()),
+                           StmtWithASTContext(*S1, AST1->getASTContext()));
+  }
+  // Get a pair of node pointers into the synthesized AST from the given code
+  // snippets. The same matcher is used for both snippets.
+  template <typename MatcherType>
+  std::tuple<StmtWithASTContext, StmtWithASTContext>
+  makeStmts(const std::string &SrcCode0, const std::string &SrcCode1,
+            TestLanguage Lang, const MatcherType &AMatcher) {
+    return makeStmts(SrcCode0, SrcCode1, Lang, AMatcher, AMatcher);
+  }
+  // Convenience function for makeStmts that wraps the code inside a function
+  // body.
+  template <typename MatcherType>
+  std::tuple<StmtWithASTContext, StmtWithASTContext>
+  makeWrappedStmts(const std::string &SrcCode0, const std::string &SrcCode1,
+                   TestLanguage Lang, const MatcherType &AMatcher) {
+    auto Wrap = [](const std::string &Src) {
+      return "void wrapped() {" + Src + ";}";
+    };
+    return makeStmts(Wrap(SrcCode0), Wrap(SrcCode1), Lang, AMatcher);
+  }
   bool testStructuralMatch(Decl *D0, Decl *D1) {
     llvm::DenseSet<std::pair<Decl *, Decl *>> NonEquivalentDecls01;
     llvm::DenseSet<std::pair<Decl *, Decl *>> NonEquivalentDecls10;
@@ -95,6 +145,26 @@
     return Eq01;
+  bool testStructuralMatch(StmtWithASTContext S0, StmtWithASTContext S1) {
+    llvm::DenseSet<std::pair<Decl *, Decl *>> NonEquivalentDecls01;
+    llvm::DenseSet<std::pair<Decl *, Decl *>> NonEquivalentDecls10;
+    StructuralEquivalenceContext Ctx01(
+        *S0.Context, *S1.Context, NonEquivalentDecls01,
+        StructuralEquivalenceKind::Default, false, false);
+    StructuralEquivalenceContext Ctx10(
+        *S1.Context, *S0.Context, NonEquivalentDecls10,
+        StructuralEquivalenceKind::Default, false, false);
+    bool Eq01 = Ctx01.IsEquivalent(S0.S, S1.S);
+    bool Eq10 = Ctx10.IsEquivalent(S1.S, S0.S);
+    EXPECT_EQ(Eq01, Eq10);
+    return Eq01;
+  }
+  bool
+  testStructuralMatch(std::tuple<StmtWithASTContext, StmtWithASTContext> t) {
+    return testStructuralMatch(get<0>(t), get<1>(t));
+  }
   bool testStructuralMatch(std::tuple<Decl *, Decl *> t) {
     return testStructuralMatch(get<0>(t), get<1>(t));
@@ -1375,5 +1445,225 @@
       findDeclPair<FunctionDecl>(TU, functionDecl(hasName("x")))));
+struct StructuralEquivalenceStmtTest : StructuralEquivalenceTest {};
+/// Fallback matcher to be used only when there is no specific matcher for a
+/// Expr subclass. Remove this once all Expr subclasses have their own matcher.
+static auto &fallbackExprMatcher = expr;
+TEST_F(StructuralEquivalenceStmtTest, AddrLabelExpr) {
+  auto t = makeWrappedStmts("lbl: &&lbl;", "lbl: &&lbl;", Lang_CXX03,
+                            addrLabelExpr());
+  EXPECT_TRUE(testStructuralMatch(t));
+TEST_F(StructuralEquivalenceStmtTest, AddrLabelExprDifferentLabel) {
+  auto t = makeWrappedStmts("lbl1: lbl2: &&lbl1;", "lbl1: lbl2: &&lbl2;",
+                            Lang_CXX03, addrLabelExpr());
+  // FIXME: Should be false. LabelDecl are incorrectly matched.
+  EXPECT_TRUE(testStructuralMatch(t));
+static const std::string MemoryOrderSrc = R"(
+enum memory_order {
+  memory_order_relaxed,
+  memory_order_consume,
+  memory_order_acquire,
+  memory_order_release,
+  memory_order_acq_rel,
+  memory_order_seq_cst
+TEST_F(StructuralEquivalenceStmtTest, AtomicExpr) {
+  std::string Prefix = "char a, b; " + MemoryOrderSrc;
+  auto t = makeStmts(
+      Prefix +
+          "void wrapped() { __atomic_load(&a, &b, memory_order_seq_cst); }",
+      Prefix +
+          "void wrapped() { __atomic_load(&a, &b, memory_order_seq_cst); }",
+      Lang_CXX03, atomicExpr());
+  EXPECT_TRUE(testStructuralMatch(t));
+TEST_F(StructuralEquivalenceStmtTest, AtomicExprDifferentOp) {
+  std::string Prefix = "char a, b; " + MemoryOrderSrc;
+  auto t = makeStmts(
+      Prefix +
+          "void wrapped() { __atomic_load(&a, &b, memory_order_seq_cst); }",
+      Prefix +
+          "void wrapped() { __atomic_store(&a, &b, memory_order_seq_cst); }",
+      Lang_CXX03, atomicExpr());
+  EXPECT_FALSE(testStructuralMatch(t));
+TEST_F(StructuralEquivalenceStmtTest, BinaryOperator) {
+  auto t = makeWrappedStmts("1 + 1", "1 + 1", Lang_CXX03, binaryOperator());
+  EXPECT_TRUE(testStructuralMatch(t));
+TEST_F(StructuralEquivalenceStmtTest, BinaryOperatorDifferentOps) {
+  auto t = makeWrappedStmts("1 + 1", "1 - 1", Lang_CXX03, binaryOperator());
+  EXPECT_FALSE(testStructuralMatch(t));
+TEST_F(StructuralEquivalenceStmtTest, CallExpr) {
+  std::string Src = "int call(); int wrapped() { call(); }";
+  auto t = makeStmts(Src, Src, Lang_CXX03, callExpr());
+  EXPECT_TRUE(testStructuralMatch(t));
+TEST_F(StructuralEquivalenceStmtTest, CallExprDifferentCallee) {
+  std::string FunctionSrc = "int func1(); int func2();\n";
+  auto t = makeStmts(FunctionSrc + "void wrapper() { func1(); }",
+                     FunctionSrc + "void wrapper() { func2(); }", Lang_CXX03,
+                     callExpr());
+  EXPECT_FALSE(testStructuralMatch(t));
+TEST_F(StructuralEquivalenceStmtTest, CharacterLiteral) {
+  auto t = makeWrappedStmts("'a'", "'a'", Lang_CXX03, characterLiteral());
+  EXPECT_TRUE(testStructuralMatch(t));
+TEST_F(StructuralEquivalenceStmtTest, CharacterLiteralDifferentValues) {
+  auto t = makeWrappedStmts("'a'", "'b'", Lang_CXX03, characterLiteral());
+  EXPECT_FALSE(testStructuralMatch(t));
+TEST_F(StructuralEquivalenceStmtTest, ExpressionTraitExpr) {
+  auto t = makeWrappedStmts("__is_lvalue_expr(1)", "__is_lvalue_expr(1)",
+                            Lang_CXX03, fallbackExprMatcher());
+  EXPECT_TRUE(testStructuralMatch(t));
+TEST_F(StructuralEquivalenceStmtTest, ExpressionTraitExprDifferentKind) {
+  auto t = makeWrappedStmts("__is_lvalue_expr(1)", "__is_rvalue_expr(1)",
+                            Lang_CXX03, fallbackExprMatcher());
+  EXPECT_FALSE(testStructuralMatch(t));
+TEST_F(StructuralEquivalenceStmtTest, FloatingLiteral) {
+  auto t = makeWrappedStmts("1.0", "1.0", Lang_CXX03, fallbackExprMatcher());
+  EXPECT_TRUE(testStructuralMatch(t));
+TEST_F(StructuralEquivalenceStmtTest, FloatingLiteralDifferentSpelling) {
+  auto t = makeWrappedStmts("0x10.1p0", "16.0625", Lang_CXX17,
+                            fallbackExprMatcher());
+  // Same value but with different spelling is equivalent.
+  EXPECT_TRUE(testStructuralMatch(t));
+TEST_F(StructuralEquivalenceStmtTest, FloatingLiteralDifferentType) {
+  auto t = makeWrappedStmts("1.0", "1.0f", Lang_CXX03, fallbackExprMatcher());
+  EXPECT_FALSE(testStructuralMatch(t));
+TEST_F(StructuralEquivalenceStmtTest, FloatingLiteralDifferentValue) {
+  auto t = makeWrappedStmts("1.01", "1.0", Lang_CXX03, fallbackExprMatcher());
+  EXPECT_FALSE(testStructuralMatch(t));
+TEST_F(StructuralEquivalenceStmtTest, IntegerLiteral) {
+  auto t = makeWrappedStmts("1", "1", Lang_CXX03, integerLiteral());
+  EXPECT_TRUE(testStructuralMatch(t));
+TEST_F(StructuralEquivalenceStmtTest, IntegerLiteralDifferentSpelling) {
+  auto t = makeWrappedStmts("1", "0x1", Lang_CXX03, integerLiteral());
+  // Same value but with different spelling is equivalent.
+  EXPECT_TRUE(testStructuralMatch(t));
+TEST_F(StructuralEquivalenceStmtTest, IntegerLiteralDifferentValue) {
+  auto t = makeWrappedStmts("1", "2", Lang_CXX03, integerLiteral());
+  EXPECT_FALSE(testStructuralMatch(t));
+TEST_F(StructuralEquivalenceStmtTest, IntegerLiteralDifferentTypes) {
+  auto t = makeWrappedStmts("1", "1L", Lang_CXX03, integerLiteral());
+  EXPECT_FALSE(testStructuralMatch(t));
+TEST_F(StructuralEquivalenceStmtTest, ObjCStringLiteral) {
+  auto t =
+      makeWrappedStmts("@\"a\"", "@\"a\"", Lang_OBJCXX, fallbackExprMatcher());
+  EXPECT_TRUE(testStructuralMatch(t));
+TEST_F(StructuralEquivalenceStmtTest, ObjCStringLiteralDifferentContent) {
+  auto t =
+      makeWrappedStmts("@\"a\"", "@\"b\"", Lang_OBJCXX, fallbackExprMatcher());
+  EXPECT_FALSE(testStructuralMatch(t));
+TEST_F(StructuralEquivalenceStmtTest, StringLiteral) {
+  auto t = makeWrappedStmts("\"a\"", "\"a\"", Lang_CXX03, stringLiteral());
+  EXPECT_TRUE(testStructuralMatch(t));
+TEST_F(StructuralEquivalenceStmtTest, StringLiteralDifferentContent) {
+  auto t = makeWrappedStmts("\"a\"", "\"b\"", Lang_CXX03, stringLiteral());
+  EXPECT_FALSE(testStructuralMatch(t));
+TEST_F(StructuralEquivalenceStmtTest, StringLiteralDifferentLength) {
+  auto t = makeWrappedStmts("\"a\"", "\"aa\"", Lang_CXX03, stringLiteral());
+  EXPECT_FALSE(testStructuralMatch(t));
+TEST_F(StructuralEquivalenceStmtTest, TypeTraitExpr) {
+  auto t = makeWrappedStmts("__is_pod(int)", "__is_pod(int)", Lang_CXX03,
+                            fallbackExprMatcher());
+  EXPECT_TRUE(testStructuralMatch(t));
+TEST_F(StructuralEquivalenceStmtTest, TypeTraitExprDifferentType) {
+  auto t = makeWrappedStmts("__is_pod(int)", "__is_pod(long)", Lang_CXX03,
+                            fallbackExprMatcher());
+  EXPECT_FALSE(testStructuralMatch(t));
+TEST_F(StructuralEquivalenceStmtTest, TypeTraitExprDifferentTrait) {
+  auto t = makeWrappedStmts(
+      "__is_pod(int)", "__is_trivially_constructible(int)", Lang_CXX03, expr());
+  EXPECT_FALSE(testStructuralMatch(t));
+TEST_F(StructuralEquivalenceStmtTest, TypeTraitExprDifferentTraits) {
+  auto t = makeWrappedStmts("__is_constructible(int)",
+                            "__is_constructible(int, int)", Lang_CXX03, expr());
+  EXPECT_FALSE(testStructuralMatch(t));
+TEST_F(StructuralEquivalenceStmtTest, UnaryExprOrTypeTraitExpr) {
+  auto t = makeWrappedStmts("sizeof(int)", "sizeof(int)", Lang_CXX03,
+                            unaryExprOrTypeTraitExpr());
+  EXPECT_TRUE(testStructuralMatch(t));
+TEST_F(StructuralEquivalenceStmtTest, UnaryExprOrTypeTraitExprDifferentKind) {
+  auto t = makeWrappedStmts("sizeof(int)", "alignof(long)", Lang_CXX11,
+                            unaryExprOrTypeTraitExpr());
+  EXPECT_FALSE(testStructuralMatch(t));
+TEST_F(StructuralEquivalenceStmtTest, UnaryExprOrTypeTraitExprDifferentType) {
+  auto t = makeWrappedStmts("sizeof(int)", "sizeof(long)", Lang_CXX03,
+                            unaryExprOrTypeTraitExpr());
+  EXPECT_FALSE(testStructuralMatch(t));
+TEST_F(StructuralEquivalenceStmtTest, UnaryOperator) {
+  auto t = makeWrappedStmts("+1", "+1", Lang_CXX03, unaryOperator());
+  EXPECT_TRUE(testStructuralMatch(t));
+TEST_F(StructuralEquivalenceStmtTest, UnaryOperatorDifferentOps) {
+  auto t = makeWrappedStmts("+1", "-1", Lang_CXX03, unaryOperator());
+  EXPECT_FALSE(testStructuralMatch(t));
 } // end namespace ast_matchers
 } // end namespace clang
Index: clang/lib/AST/ASTStructuralEquivalence.cpp
--- clang/lib/AST/ASTStructuralEquivalence.cpp
+++ clang/lib/AST/ASTStructuralEquivalence.cpp
@@ -68,7 +68,12 @@
 #include "clang/AST/DeclObjC.h"
 #include "clang/AST/DeclTemplate.h"
 #include "clang/AST/ExprCXX.h"
+#include "clang/AST/ExprConcepts.h"
+#include "clang/AST/ExprObjC.h"
+#include "clang/AST/ExprOpenMP.h"
 #include "clang/AST/NestedNameSpecifier.h"
+#include "clang/AST/StmtObjC.h"
+#include "clang/AST/StmtOpenMP.h"
 #include "clang/AST/TemplateBase.h"
 #include "clang/AST/TemplateName.h"
 #include "clang/AST/Type.h"
@@ -149,32 +154,230 @@
   return true;
-/// Determine structural equivalence of two expressions.
-static bool IsStructurallyEquivalent(StructuralEquivalenceContext &Context,
-                                     const Expr *E1, const Expr *E2) {
-  if (!E1 || !E2)
-    return E1 == E2;
+namespace {
+/// Encapsulates Stmt comparison logic.
+class StmtComparer {
+  StructuralEquivalenceContext &Context;
+  // IsStmtEquivalent overloads. Each overload compares one specific statement
+  // class and only has to compare the data that is specific to that class.
+  // Should only be called from TraverseStmt.
+  bool IsStmtEquivalent(const AddrLabelExpr *E1, const AddrLabelExpr *E2) {
+    return IsStructurallyEquivalent(Context, E1->getLabel(), E2->getLabel());
+  }
+  bool IsStmtEquivalent(const AtomicExpr *E1, const AtomicExpr *E2) {
+    return E1->getOp() == E2->getOp();
+  }
+  bool IsStmtEquivalent(const BinaryOperator *E1, const BinaryOperator *E2) {
+    return E1->getOpcode() == E2->getOpcode();
+  }
-  if (auto *DE1 = dyn_cast<DependentScopeDeclRefExpr>(E1)) {
-    auto *DE2 = dyn_cast<DependentScopeDeclRefExpr>(E2);
-    if (!DE2)
+  bool IsStmtEquivalent(const CallExpr *E1, const CallExpr *E2) {
+    // FIXME: IsStructurallyEquivalent requires non-const Decls.
+    Decl *Callee1 = const_cast<Decl *>(E1->getCalleeDecl());
+    Decl *Callee2 = const_cast<Decl *>(E2->getCalleeDecl());
+    // Compare whether both calls know their callee.
+    if (static_cast<bool>(Callee1) != static_cast<bool>(Callee2))
       return false;
+    // Both calls have no callee, so nothing to do.
+    if (!static_cast<bool>(Callee1))
+      return true;
+    assert(Callee2);
+    return IsStructurallyEquivalent(Context, Callee1, Callee2);
+  }
+  bool IsStmtEquivalent(const CharacterLiteral *E1,
+                        const CharacterLiteral *E2) {
+    return E1->getValue() == E2->getValue() && E1->getKind() == E2->getKind();
+  }
+  bool IsStmtEquivalent(const ChooseExpr *E1, const ChooseExpr *E2) {
+    return true; // Semantics only depend on children.
+  }
+  bool IsStmtEquivalent(const CompoundStmt *E1, const CompoundStmt *E2) {
+    // Number of children is actually checked by the generic children comparison
+    // code, but a CompoundStmt is one of the few statements where the number of
+    // children frequently differs and the number of statements is also always
+    // precomputed. Directly comparing the number of children here is thus
+    // just an optimization.
+    return E1->size() == E2->size();
+  }
+  bool IsStmtEquivalent(const DependentScopeDeclRefExpr *DE1,
+                        const DependentScopeDeclRefExpr *DE2) {
     if (!IsStructurallyEquivalent(Context, DE1->getDeclName(),
       return false;
     return IsStructurallyEquivalent(Context, DE1->getQualifier(),
-  } else if (auto CastE1 = dyn_cast<ImplicitCastExpr>(E1)) {
-    auto *CastE2 = dyn_cast<ImplicitCastExpr>(E2);
-    if (!CastE2)
+  }
+  bool IsStmtEquivalent(const Expr *E1, const Expr *E2) {
+    return IsStructurallyEquivalent(Context, E1->getType(), E2->getType());
+  }
+  bool IsStmtEquivalent(const ExpressionTraitExpr *E1,
+                        const ExpressionTraitExpr *E2) {
+    return E1->getTrait() == E2->getTrait() && E1->getValue() == E2->getValue();
+  }
+  bool IsStmtEquivalent(const FloatingLiteral *E1, const FloatingLiteral *E2) {
+    return E1->isExact() == E2->isExact() && E1->getValue() == E2->getValue();
+  }
+  bool IsStmtEquivalent(const ImplicitCastExpr *CastE1,
+                        const ImplicitCastExpr *CastE2) {
+    return IsStructurallyEquivalent(Context, CastE1->getType(),
+                                    CastE2->getType());
+  }
+  bool IsStmtEquivalent(const IntegerLiteral *E1, const IntegerLiteral *E2) {
+    return E1->getValue() == E2->getValue();
+  }
+  bool IsStmtEquivalent(const ObjCStringLiteral *E1,
+                        const ObjCStringLiteral *E2) {
+    // Just wraps a StringLiteral child.
+    return true;
+  }
+  bool IsStmtEquivalent(const Stmt *S1, const Stmt *S2) { return true; }
+  bool IsStmtEquivalent(const SourceLocExpr *E1, const SourceLocExpr *E2) {
+    return E1->getIdentKind() == E2->getIdentKind();
+  }
+  bool IsStmtEquivalent(const StmtExpr *E1, const StmtExpr *E2) {
+    return E1->getTemplateDepth() == E2->getTemplateDepth();
+  }
+  bool IsStmtEquivalent(const StringLiteral *E1, const StringLiteral *E2) {
+    return E1->getBytes() == E2->getBytes();
+  }
+  bool IsStmtEquivalent(const SubstNonTypeTemplateParmExpr *E1,
+                        const SubstNonTypeTemplateParmExpr *E2) {
+    return IsStructurallyEquivalent(Context, E1->getParameter(),
+                                    E2->getParameter());
+  }
+  bool IsStmtEquivalent(const SubstNonTypeTemplateParmPackExpr *E1,
+                        const SubstNonTypeTemplateParmPackExpr *E2) {
+    return IsStructurallyEquivalent(Context, E1->getArgumentPack(),
+                                    E2->getArgumentPack());
+  }
+  bool IsStmtEquivalent(const TypeTraitExpr *E1, const TypeTraitExpr *E2) {
+    if (E1->getTrait() != E2->getTrait())
+      return false;
+    for (auto Pair : zip_longest(E1->getArgs(), E2->getArgs())) {
+      Optional<TypeSourceInfo *> Child1 = std::get<0>(Pair);
+      Optional<TypeSourceInfo *> Child2 = std::get<1>(Pair);
+      // Different number of args.
+      if (!Child1 || !Child2)
+        return false;
+      if (!IsStructurallyEquivalent(Context, (*Child1)->getType(),
+                                    (*Child2)->getType()))
+        return false;
+    }
+    return true;
+  }
+  bool IsStmtEquivalent(const UnaryExprOrTypeTraitExpr *E1,
+                        const UnaryExprOrTypeTraitExpr *E2) {
+    if (E1->getKind() != E2->getKind())
+      return false;
+    return IsStructurallyEquivalent(Context, E1->getTypeOfArgument(),
+                                    E2->getTypeOfArgument());
+  }
+  bool IsStmtEquivalent(const UnaryOperator *E1, const UnaryOperator *E2) {
+    return E1->getOpcode() == E2->getOpcode();
+  }
+  bool IsStmtEquivalent(const VAArgExpr *E1, const VAArgExpr *E2) {
+    // Semantics only depend on children.
+    return true;
+  }
+  /// End point of the traversal chain.
+  bool TraverseStmt(const Stmt *S1, const Stmt *S2) { return true; }
+  // Create traversal methods that traverse the class hierarchy and return
+  // the accumulated result of the comparison. Each TraverseStmt overload
+  // calls the TraverseStmt overload of the parent class. For example,
+  // the TraverseStmt overload for 'BinaryOperator' calls the TraverseStmt
+  // overload of 'Expr' which then calls the overload for 'Stmt'.
+#define STMT(CLASS, PARENT)                                                    \
+  bool TraverseStmt(const CLASS *S1, const CLASS *S2) {                        \
+    if (!TraverseStmt(static_cast<const PARENT *>(S1),                         \
+                      static_cast<const PARENT *>(S2)))                        \
+      return false;                                                            \
+    return IsStmtEquivalent(S1, S2);                                           \
+  }
+#include "clang/AST/StmtNodes.inc"
+  StmtComparer(StructuralEquivalenceContext &C) : Context(C) {}
+  /// Determine whether two statements are equivalent. The statements have to
+  /// be of the same kind. The children of the statements and their properties
+  /// are not compared by this function.
+  bool IsEquivalent(const Stmt *S1, const Stmt *S2) {
+    if (S1->getStmtClass() != S2->getStmtClass())
+      return false;
+    // Each TraverseStmt walks the class hierarchy from the leaf class to
+    // the root class 'Stmt' (e.g. 'BinaryOperator' -> 'Expr' -> 'Stmt'). Cast
+    // the Stmt we have here to its specific subclass so that we call the
+    // overload that walks the whole class hierarchy from leaf to root (e.g.,
+    // cast to 'BinaryOperator' so that 'Expr' and 'Stmt' are traversed).
+    switch (S1->getStmtClass()) {
+    case Stmt::NoStmtClass:
+      llvm_unreachable("Can't traverse NoStmtClass");
+#define STMT(CLASS, PARENT)                                                    \
+  case Stmt::StmtClass::CLASS##Class:                                          \
+    return TraverseStmt(static_cast<const CLASS *>(S1),                        \
+                        static_cast<const CLASS *>(S2));
+#include "clang/AST/StmtNodes.inc"
+    }
+    llvm_unreachable("Invalid statement kind");
+  }
+} // namespace
+/// Determine structural equivalence of two statements.
+static bool IsStructurallyEquivalent(StructuralEquivalenceContext &Context,
+                                     const Stmt *S1, const Stmt *S2) {
+  if (!S1 || !S2)
+    return S1 == S2;
+  // Compare the statements themselves.
+  StmtComparer Comparer(Context);
+  if (!Comparer.IsEquivalent(S1, S2))
+    return false;
+  // Iterate over the children of both statements and also compare them.
+  for (auto Pair : zip_longest(S1->children(), S2->children())) {
+    Optional<const Stmt *> Child1 = std::get<0>(Pair);
+    Optional<const Stmt *> Child2 = std::get<1>(Pair);
+    // One of the statements has a different number of children than the other,
+    // so the statements can't be equivalent.
+    if (!Child1 || !Child2)
       return false;
-    if (!IsStructurallyEquivalent(Context, CastE1->getType(),
-                                  CastE2->getType()))
+    if (!IsStructurallyEquivalent(Context, *Child1, *Child2))
       return false;
-    return IsStructurallyEquivalent(Context, CastE1->getSubExpr(),
-                                    CastE2->getSubExpr());
-  // FIXME: Handle other kind of expressions!
   return true;
@@ -1790,6 +1993,15 @@
   return !Finish();
+bool StructuralEquivalenceContext::IsEquivalent(Stmt *S1, Stmt *S2) {
+  assert(DeclsToCheck.empty());
+  assert(VisitedDecls.empty());
+  if (!::IsStructurallyEquivalent(*this, S1, S2))
+    return false;
+  return !Finish();
 bool StructuralEquivalenceContext::CheckCommonEquivalence(Decl *D1, Decl *D2) {
   // Check for equivalent described template.
   TemplateDecl *Template1 = D1->getDescribedTemplate();
Index: clang/include/clang/AST/ASTStructuralEquivalence.h
--- clang/include/clang/AST/ASTStructuralEquivalence.h
+++ clang/include/clang/AST/ASTStructuralEquivalence.h
@@ -97,6 +97,13 @@
   /// \c VisitedDecls members) and can cause faulty equivalent results.
   bool IsEquivalent(QualType T1, QualType T2);
+  /// Determine whether the two statements are structurally equivalent.
+  /// Implementation functions (all static functions in
+  /// ASTStructuralEquivalence.cpp) must never call this function because that
+  /// will wreak havoc on the internal state (\c DeclsToCheck and
+  /// \c VisitedDecls members) and can cause faulty equivalence results.
+  bool IsEquivalent(Stmt *S1, Stmt *S2);
   /// Find the index of the given anonymous struct/union within its
   /// context.
cfe-commits mailing list

Reply via email to