teemperor created this revision.
teemperor added reviewers: zaks.anna, v.g.vassilev, doug.gregor, chandlerc.
teemperor added a subscriber: cfe-commits.

This patch adds postorder traversal support to the RecursiveASTVisitor.

This feature must be explicitly enabled by overriding
shouldTraversePostOrder() to return true,
because it has performance drawbacks for the iterative Stmt traversal.

http://reviews.llvm.org/D20382

Files:
  include/clang/AST/RecursiveASTVisitor.h
  unittests/AST/CMakeLists.txt
  unittests/AST/PostOrderASTVisitor.cpp

Index: unittests/AST/PostOrderASTVisitor.cpp
===================================================================
--- /dev/null
+++ unittests/AST/PostOrderASTVisitor.cpp
@@ -0,0 +1,112 @@
+//===- unittests/AST/PostOrderASTVisitor.cpp - Post-order visitor tests ---===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains tests for the post-order traversing functionality
+// of RecursiveASTVisitor.
+//
+//===----------------------------------------------------------------------===//
+
+#include "gtest/gtest.h"
+#include "clang/AST/RecursiveASTVisitor.h"
+#include "clang/Tooling/Tooling.h"
+
+using namespace clang;
+
+namespace {
+
+  class RecordingVisitor
+    : public RecursiveASTVisitor<RecordingVisitor> {
+
+    bool VisitPostOrder;
+  public:
+    explicit RecordingVisitor(bool VisitPostOrder)
+      : VisitPostOrder(VisitPostOrder) {
+    }
+
+    // List of visited nodes during traversal.
+    std::vector<std::string> VisitedNodes;
+
+    bool shouldTraversePostOrder() const { return VisitPostOrder; }
+
+    bool PostVisitBinaryOperator(BinaryOperator *Op) {
+      VisitedNodes.push_back(Op->getOpcodeStr());
+      return true;
+    }
+
+    bool PostVisitIntegerLiteral(IntegerLiteral *Lit) {
+      VisitedNodes.push_back(Lit->getValue().toString(10, false));
+      return true;
+    }
+
+    bool PostVisitCXXMethodDecl(CXXMethodDecl *D) {
+      VisitedNodes.push_back(D->getQualifiedNameAsString());
+      return true;
+    }
+
+    bool PostVisitReturnStmt(Stmt *S) {
+      VisitedNodes.push_back("return");
+      return true;
+    }
+
+    bool PostVisitCXXRecordDecl(CXXRecordDecl *Declaration) {
+      VisitedNodes.push_back(Declaration->getQualifiedNameAsString());
+      return true;
+    }
+
+    bool PostVisitTemplateTypeParmType(TemplateTypeParmType *T) {
+      VisitedNodes.push_back(T->getDecl()->getQualifiedNameAsString());
+      return true;
+    }
+  };
+
+}
+
+TEST(RecursiveASTVisitor, PostOrderTraversal) {
+  auto ASTUnit = tooling::buildASTFromCode(
+    "template <class T> class A {"
+    "  class B {"
+    "    int foo() { return 1 + 2; }"
+    "  };"
+    "};"
+  );
+  auto TU = ASTUnit->getASTContext().getTranslationUnitDecl();
+  // We traverse the translation unit and store all
+  // visited nodes.
+  RecordingVisitor Visitor(true);
+  Visitor.TraverseTranslationUnitDecl(TU);
+
+  std::vector<std::string> expected = {
+    "1", "2", "+", "return", "A::B::foo", "A::B", "A", "A::T"
+  };
+  // Compare the list of actually visited nodes
+  // with the expected list of visited nodes.
+  ASSERT_EQ(expected.size(), Visitor.VisitedNodes.size());
+  for (std::size_t I = 0; I < expected.size(); I++) {
+    ASSERT_EQ(expected[I], Visitor.VisitedNodes[I]);
+  }
+}
+
+TEST(RecursiveASTVisitor, DeactivatePostOrderTraversal) {
+  auto ASTUnit = tooling::buildASTFromCode(
+    "template <class T> class A {"
+    "  class B {"
+    "    int foo() { return 1 + 2; }"
+    "  };"
+    "};"
+  );
+  auto TU = ASTUnit->getASTContext().getTranslationUnitDecl();
+  // We try to traverse the translation unit but with deactivated
+  // post order calls.
+  RecordingVisitor Visitor(false);
+  Visitor.TraverseTranslationUnitDecl(TU);
+
+  // We deactivated postorder traversal, so we shouldn't have
+  // recorded any nodes.
+  ASSERT_TRUE(Visitor.VisitedNodes.empty());
+}
Index: unittests/AST/CMakeLists.txt
===================================================================
--- unittests/AST/CMakeLists.txt
+++ unittests/AST/CMakeLists.txt
@@ -14,6 +14,7 @@
   EvaluateAsRValueTest.cpp
   ExternalASTSourceTest.cpp
   NamedDeclPrinterTest.cpp
+  PostOrderASTVisitor.cpp
   SourceLocationTest.cpp
   StmtPrinterTest.cpp
   )
Index: include/clang/AST/RecursiveASTVisitor.h
===================================================================
--- include/clang/AST/RecursiveASTVisitor.h
+++ include/clang/AST/RecursiveASTVisitor.h
@@ -72,8 +72,8 @@
       return false;                                                            \
   } while (0)
 
-/// \brief A class that does preorder depth-first traversal on the
-/// entire Clang AST and visits each node.
+/// \brief A class that does preorder (and optional postorder)
+/// depth-first traversal on the entire Clang AST and visits each node.
 ///
 /// This class performs three distinct tasks:
 ///   1. traverse the AST (i.e. go to each node);
@@ -133,6 +133,14 @@
 /// to return true, in which case all known implicit and explicit
 /// instantiations will be visited at the same time as the pattern
 /// from which they were produced.
+///
+/// By default, this visitor traverses the AST in preorder. If postorder
+/// traversal is needed, the \c shouldTraversePostOrder method needs to be
+/// overridden to return \c true.
+/// The visitor will then call the PostWalkUpFromFoo(Foo *x)
+/// and PostVisitFoo(Foo *x) methods, which behave in the same way as their
+/// counterparts without the 'Post' prefix, except that they
+/// are called as if the AST were traversed in postorder.
 template <typename Derived> class RecursiveASTVisitor {
 public:
   /// A queue used for performing data recursion over statements.
@@ -158,6 +166,10 @@
   /// code, e.g., implicit constructors and destructors.
   bool shouldVisitImplicitCode() const { return false; }
 
+  /// \brief Return whether this visitor should call PostVisit
+  /// and PostWalkUpFrom functions.
+  bool shouldTraversePostOrder() const { return false; }
+
   /// \brief Recursively visit a statement or expression, by
   /// dispatching to Traverse*() based on the argument's dynamic type.
   ///
@@ -342,14 +354,27 @@
   bool Visit##CLASS(CLASS *S) { return true; }
 #include "clang/AST/StmtNodes.inc"
 
+
+  // Define PostWalkUpFrom*() and empty PostVisit*() for all Stmt classes.
+  bool PostWalkUpFromStmt(Stmt *S) { return getDerived().PostVisitStmt(S); }
+  bool PostVisitStmt(Stmt *S) { return true; }
+#define STMT(CLASS, PARENT)                                                    \
+  bool PostWalkUpFrom##CLASS(CLASS *S) {                                       \
+    TRY_TO(PostWalkUpFrom##PARENT(S));                                         \
+    TRY_TO(PostVisit##CLASS(S));                                               \
+    return true;                                                               \
+  }                                                                            \
+  bool PostVisit##CLASS(CLASS *S) { return true; }
+#include "clang/AST/StmtNodes.inc"
+
 // Define Traverse*(), WalkUpFrom*(), and Visit*() for unary
 // operator methods.  Unary operators are not classes in themselves
 // (they're all opcodes in UnaryOperator) but do have visitors.
 #define OPERATOR(NAME)                                                         \
   bool TraverseUnary##NAME(UnaryOperator *S,                                   \
                            DataRecursionQueue *Queue = nullptr) {              \
     TRY_TO(WalkUpFromUnary##NAME(S));                                          \
-    TRY_TO_TRAVERSE_OR_ENQUEUE_STMT(S->getSubExpr());                                    \
+    TRY_TO_TRAVERSE_OR_ENQUEUE_STMT(S->getSubExpr());                          \
     return true;                                                               \
   }                                                                            \
   bool WalkUpFromUnary##NAME(UnaryOperator *S) {                               \
@@ -368,8 +393,8 @@
 #define GENERAL_BINOP_FALLBACK(NAME, BINOP_TYPE)                               \
   bool TraverseBin##NAME(BINOP_TYPE *S, DataRecursionQueue *Queue = nullptr) { \
     TRY_TO(WalkUpFromBin##NAME(S));                                            \
-    TRY_TO_TRAVERSE_OR_ENQUEUE_STMT(S->getLHS());                                        \
-    TRY_TO_TRAVERSE_OR_ENQUEUE_STMT(S->getRHS());                                        \
+    TRY_TO_TRAVERSE_OR_ENQUEUE_STMT(S->getLHS());                              \
+    TRY_TO_TRAVERSE_OR_ENQUEUE_STMT(S->getRHS());                              \
     return true;                                                               \
   }                                                                            \
   bool WalkUpFromBin##NAME(BINOP_TYPE *S) {                                    \
@@ -415,6 +440,20 @@
   bool Visit##CLASS##Type(CLASS##Type *T) { return true; }
 #include "clang/AST/TypeNodes.def"
 
+
+
+  // Define PostWalkUpFrom*() and empty PostVisit*() for all Type classes.
+  bool PostWalkUpFromType(Type *T) { return getDerived().PostVisitType(T); }
+  bool PostVisitType(Type *T) { return true; }
+#define TYPE(CLASS, BASE)                                                      \
+  bool PostWalkUpFrom##CLASS##Type(CLASS##Type *T) {                           \
+    TRY_TO(PostWalkUpFrom##BASE(T));                                           \
+    TRY_TO(PostVisit##CLASS##Type(T));                                         \
+    return true;                                                               \
+  }                                                                            \
+  bool PostVisit##CLASS##Type(CLASS##Type *T) { return true; }
+#include "clang/AST/TypeNodes.def"
+
 // ---- Methods on TypeLocs ----
 // FIXME: this currently just calls the matching Type methods
 
@@ -469,6 +508,18 @@
   bool Visit##CLASS##Decl(CLASS##Decl *D) { return true; }
 #include "clang/AST/DeclNodes.inc"
 
+  // Define PostWalkUpFrom*() and empty PostVisit*() for all Decl classes.
+  bool PostWalkUpFromDecl(Decl *D) { return getDerived().PostVisitDecl(D); }
+  bool PostVisitDecl(Decl *D) { return true; }
+#define DECL(CLASS, BASE)                                                      \
+  bool PostWalkUpFrom##CLASS##Decl(CLASS##Decl *D) {                           \
+    TRY_TO(PostWalkUpFrom##BASE(D));                                           \
+    TRY_TO(PostVisit##CLASS##Decl(D));                                         \
+    return true;                                                               \
+  }                                                                            \
+  bool PostVisit##CLASS##Decl(CLASS##Decl *D) { return true; }
+#include "clang/AST/DeclNodes.inc"
+
 private:
   // These are helper methods used by more than one Traverse* method.
   bool TraverseTemplateParameterListHelper(TemplateParameterList *TPL);
@@ -499,6 +550,7 @@
   bool VisitOMPClauseWithPostUpdate(OMPClauseWithPostUpdate *Node);
 
   bool dataTraverseNode(Stmt *S, DataRecursionQueue *Queue);
+  bool PostVisitNode(Stmt *S);
 };
 
 template <typename Derived>
@@ -557,8 +609,29 @@
 #undef DISPATCH_STMT
 
 template <typename Derived>
+bool RecursiveASTVisitor<Derived>::PostVisitNode(Stmt *S) {
+
+  // Top switch stmt: dispatch to PostWalkUpFromFooStmt for each concrete
+  // FooStmt.
+  switch (S->getStmtClass()) {
+  case Stmt::NoStmtClass:
+    break;
+#define ABSTRACT_STMT(STMT)
+#define STMT(CLASS, PARENT)                                                    \
+  case Stmt::CLASS##Class:                                                     \
+    PostWalkUpFrom##CLASS(static_cast<CLASS *>(S)); break;
+#include "clang/AST/StmtNodes.inc"
+  }
+
+  return true;
+}
+
+#undef DISPATCH_STMT
+
+template <typename Derived>
 bool RecursiveASTVisitor<Derived>::TraverseStmt(Stmt *S,
                                                 DataRecursionQueue *Queue) {
+
   if (!S)
     return true;
 
@@ -570,6 +643,9 @@
   SmallVector<llvm::PointerIntPair<Stmt *, 1, bool>, 8> LocalQueue;
   LocalQueue.push_back({S, false});
 
+  SmallVector<Stmt *, 16> ReverseLocalQueue;
+  ReverseLocalQueue.push_back(S);
+
   while (!LocalQueue.empty()) {
     auto &CurrSAndVisited = LocalQueue.back();
     Stmt *CurrS = CurrSAndVisited.getPointer();
@@ -586,11 +662,25 @@
       TRY_TO(dataTraverseNode(CurrS, &LocalQueue));
       // Process new children in the order they were added.
       std::reverse(LocalQueue.begin() + N, LocalQueue.end());
+
+      if (getDerived().shouldTraversePostOrder()) {
+        for (std::size_t i = N; i < LocalQueue.size(); ++i) {
+          ReverseLocalQueue.push_back(LocalQueue[i].getPointer());
+        }
+      }
     } else {
       LocalQueue.pop_back();
     }
   }
 
+  if (getDerived().shouldTraversePostOrder()) {
+    for (auto Iter = ReverseLocalQueue.rbegin();
+         Iter != ReverseLocalQueue.rend(); ++Iter) {
+      TRY_TO(PostVisitNode(*Iter));
+    }
+  }
+
+
   return true;
 }
 
@@ -876,6 +966,8 @@
   bool RecursiveASTVisitor<Derived>::Traverse##TYPE(TYPE *T) {                 \
     TRY_TO(WalkUpFrom##TYPE(T));                                               \
     { CODE; }                                                                  \
+    if (getDerived().shouldTraversePostOrder())                                \
+      TRY_TO(PostWalkUpFrom##TYPE(T));                                         \
     return true;                                                               \
   }
 
@@ -1278,10 +1370,15 @@
 #define DEF_TRAVERSE_DECL(DECL, CODE)                                          \
   template <typename Derived>                                                  \
   bool RecursiveASTVisitor<Derived>::Traverse##DECL(DECL *D) {                 \
+    bool ShouldVisitChildren = true;                                           \
+    bool ReturnValue = true;                                                   \
     TRY_TO(WalkUpFrom##DECL(D));                                               \
     { CODE; }                                                                  \
-    TRY_TO(TraverseDeclContextHelper(dyn_cast<DeclContext>(D)));               \
-    return true;                                                               \
+    if (ReturnValue && ShouldVisitChildren)                                    \
+      TRY_TO(TraverseDeclContextHelper(dyn_cast<DeclContext>(D)));             \
+    if (ReturnValue && getDerived().shouldTraversePostOrder())                 \
+      TRY_TO(PostWalkUpFrom##DECL(D));                                         \
+    return ReturnValue;                                                        \
   }
 
 DEF_TRAVERSE_DECL(AccessSpecDecl, {})
@@ -1295,18 +1392,12 @@
       TRY_TO(TraverseStmt(I.getCopyExpr()));
     }
   }
-  // This return statement makes sure the traversal of nodes in
-  // decls_begin()/decls_end() (done in the DEF_TRAVERSE_DECL macro)
-  // is skipped - don't remove it.
-  return true;
+  ShouldVisitChildren = false;
 })
 
 DEF_TRAVERSE_DECL(CapturedDecl, {
   TRY_TO(TraverseStmt(D->getBody()));
-  // This return statement makes sure the traversal of nodes in
-  // decls_begin()/decls_end() (done in the DEF_TRAVERSE_DECL macro)
-  // is skipped - don't remove it.
-  return true;
+  ShouldVisitChildren = false;
 })
 
 DEF_TRAVERSE_DECL(EmptyDecl, {})
@@ -1376,11 +1467,7 @@
 
   // We shouldn't traverse an aliased namespace, since it will be
   // defined (and, therefore, traversed) somewhere else.
-  //
-  // This return statement makes sure the traversal of nodes in
-  // decls_begin()/decls_end() (done in the DEF_TRAVERSE_DECL macro)
-  // is skipped - don't remove it.
-  return true;
+  ShouldVisitChildren = false;
 })
 
 DEF_TRAVERSE_DECL(LabelDecl, {// There is no code in a LabelDecl.
@@ -1436,7 +1523,7 @@
   if (D->isThisDeclarationADefinition()) {
     TRY_TO(TraverseStmt(D->getBody()));
   }
-  return true;
+  ShouldVisitChildren = false;
 })
 
 DEF_TRAVERSE_DECL(ObjCTypeParamDecl, {
@@ -1453,7 +1540,7 @@
     TRY_TO(TraverseTypeLoc(D->getTypeSourceInfo()->getTypeLoc()));
   else
     TRY_TO(TraverseType(D->getType()));
-  return true;
+  ShouldVisitChildren = false;
 })
 
 DEF_TRAVERSE_DECL(UsingDecl, {
@@ -1854,33 +1941,38 @@
 DEF_TRAVERSE_DECL(FunctionDecl, {
   // We skip decls_begin/decls_end, which are already covered by
   // TraverseFunctionHelper().
-  return TraverseFunctionHelper(D);
+  ShouldVisitChildren = false;
+  ReturnValue = TraverseFunctionHelper(D);
 })
 
 DEF_TRAVERSE_DECL(CXXMethodDecl, {
   // We skip decls_begin/decls_end, which are already covered by
   // TraverseFunctionHelper().
-  return TraverseFunctionHelper(D);
+  ShouldVisitChildren = false;
+  ReturnValue = TraverseFunctionHelper(D);
 })
 
 DEF_TRAVERSE_DECL(CXXConstructorDecl, {
   // We skip decls_begin/decls_end, which are already covered by
   // TraverseFunctionHelper().
-  return TraverseFunctionHelper(D);
+  ShouldVisitChildren = false;
+  ReturnValue = TraverseFunctionHelper(D);
 })
 
 // CXXConversionDecl is the declaration of a type conversion operator.
 // It's not a cast expression.
 DEF_TRAVERSE_DECL(CXXConversionDecl, {
   // We skip decls_begin/decls_end, which are already covered by
   // TraverseFunctionHelper().
-  return TraverseFunctionHelper(D);
+  ShouldVisitChildren = false;
+  ReturnValue = TraverseFunctionHelper(D);
 })
 
 DEF_TRAVERSE_DECL(CXXDestructorDecl, {
   // We skip decls_begin/decls_end, which are already covered by
   // TraverseFunctionHelper().
-  return TraverseFunctionHelper(D);
+  ShouldVisitChildren = false;
+  ReturnValue = TraverseFunctionHelper(D);
 })
 
 template <typename Derived>
@@ -1932,12 +2024,18 @@
   template <typename Derived>                                                  \
   bool RecursiveASTVisitor<Derived>::Traverse##STMT(                           \
       STMT *S, DataRecursionQueue *Queue) {                                    \
+    bool ShouldVisitChildren = true;                                           \
+    bool ReturnValue = true;                                                   \
     TRY_TO(WalkUpFrom##STMT(S));                                               \
     { CODE; }                                                                  \
-    for (Stmt *SubStmt : S->children()) {                                      \
-      TRY_TO_TRAVERSE_OR_ENQUEUE_STMT(SubStmt);                                          \
+    if (ShouldVisitChildren) {                                                 \
+      for (Stmt *SubStmt : S->children()) {                                    \
+        TRY_TO_TRAVERSE_OR_ENQUEUE_STMT(SubStmt);                              \
+      }                                                                        \
     }                                                                          \
-    return true;                                                               \
+    if (!Queue && ReturnValue && getDerived().shouldTraversePostOrder())       \
+      TRY_TO(PostWalkUpFrom##STMT(S));                                         \
+    return ReturnValue;                                                        \
   }
 
 DEF_TRAVERSE_STMT(GCCAsmStmt, {
@@ -1974,7 +2072,7 @@
   // initializer]'.  The decls above already traverse over the
   // initializers, so we don't have to do it again (which
   // children() would do).
-  return true;
+  ShouldVisitChildren = false;
 })
 
 // These non-expr stmts (most of them), do not need any action except
@@ -2006,7 +2104,7 @@
     TRY_TO_TRAVERSE_OR_ENQUEUE_STMT(S->getRangeInit());
     TRY_TO_TRAVERSE_OR_ENQUEUE_STMT(S->getBody());
     // Visit everything else only if shouldVisitImplicitCode().
-    return true;
+    ShouldVisitChildren = false;
   }
 })
 DEF_TRAVERSE_STMT(MSDependentExistsStmt, {
@@ -2103,7 +2201,7 @@
       S->isSemanticForm() ? S->getSyntacticForm() : S, Queue));
   TRY_TO(TraverseSynOrSemInitListExpr(
       S->isSemanticForm() ? S : S->getSemanticForm(), Queue));
-  return true;
+  ShouldVisitChildren = false;
 })
 
 // GenericSelectionExpr is a special case because the types and expressions
@@ -2116,7 +2214,7 @@
       TRY_TO(TraverseTypeLoc(TS->getTypeLoc()));
     TRY_TO_TRAVERSE_OR_ENQUEUE_STMT(S->getAssocExpr(i));
   }
-  return true;
+  ShouldVisitChildren = false;
 })
 
 // PseudoObjectExpr is a special case because of the weirdness with
@@ -2131,7 +2229,7 @@
       sub = OVE->getSourceExpr();
     TRY_TO_TRAVERSE_OR_ENQUEUE_STMT(sub);
   }
-  return true;
+  ShouldVisitChildren = false;
 })
 
 DEF_TRAVERSE_STMT(CXXScalarValueInitExpr, {
@@ -2235,7 +2333,8 @@
       TRY_TO_TRAVERSE_OR_ENQUEUE_STMT(NE);
   }
 
-  return TRAVERSE_STMT_BASE(LambdaBody, LambdaExpr, S, Queue);
+  ReturnValue = TRAVERSE_STMT_BASE(LambdaBody, LambdaExpr, S, Queue);
+  ShouldVisitChildren = false;
 })
 
 DEF_TRAVERSE_STMT(CXXUnresolvedConstructExpr, {
@@ -2361,25 +2460,25 @@
 DEF_TRAVERSE_STMT(CoroutineBodyStmt, {
   if (!getDerived().shouldVisitImplicitCode()) {
     TRY_TO_TRAVERSE_OR_ENQUEUE_STMT(S->getBody());
-    return true;
+    ShouldVisitChildren = false;
   }
 })
 DEF_TRAVERSE_STMT(CoreturnStmt, {
   if (!getDerived().shouldVisitImplicitCode()) {
     TRY_TO_TRAVERSE_OR_ENQUEUE_STMT(S->getOperand());
-    return true;
+    ShouldVisitChildren = false;
   }
 })
 DEF_TRAVERSE_STMT(CoawaitExpr, {
   if (!getDerived().shouldVisitImplicitCode()) {
     TRY_TO_TRAVERSE_OR_ENQUEUE_STMT(S->getOperand());
-    return true;
+    ShouldVisitChildren = false;
   }
 })
 DEF_TRAVERSE_STMT(CoyieldExpr, {
   if (!getDerived().shouldVisitImplicitCode()) {
     TRY_TO_TRAVERSE_OR_ENQUEUE_STMT(S->getOperand());
-    return true;
+    ShouldVisitChildren = false;
   }
 })
 
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
http://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to