dgoldman updated this revision to Diff 408928.
dgoldman marked an inline comment as done.
dgoldman added a comment.

Add tests and public accessors


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D119363/new/

https://reviews.llvm.org/D119363

Files:
  clang/include/clang/AST/ASTFwd.h
  clang/include/clang/AST/ASTTypeTraits.h
  clang/include/clang/AST/RecursiveASTVisitor.h
  clang/include/clang/AST/TypeLoc.h
  clang/lib/AST/ASTTypeTraits.cpp
  clang/lib/AST/ParentMapContext.cpp
  clang/unittests/AST/RecursiveASTVisitorTest.cpp

Index: clang/unittests/AST/RecursiveASTVisitorTest.cpp
===================================================================
--- clang/unittests/AST/RecursiveASTVisitorTest.cpp
+++ clang/unittests/AST/RecursiveASTVisitorTest.cpp
@@ -60,6 +60,12 @@
   EndTraverseEnum,
   StartTraverseTypedefType,
   EndTraverseTypedefType,
+  StartTraverseObjCInterface,
+  EndTraverseObjCInterface,
+  StartTraverseObjCProtocol,
+  EndTraverseObjCProtocol,
+  StartTraverseObjCProtocolLoc,
+  EndTraverseObjCProtocolLoc,
 };
 
 class CollectInterestingEvents
@@ -97,18 +103,43 @@
     return Ret;
   }
 
+  bool TraverseObjCInterfaceDecl(ObjCInterfaceDecl *ID) {
+    Events.push_back(VisitEvent::StartTraverseObjCInterface);
+    bool Ret = RecursiveASTVisitor::TraverseObjCInterfaceDecl(ID);
+    Events.push_back(VisitEvent::EndTraverseObjCInterface);
+
+    return Ret;
+  }
+
+  bool TraverseObjCProtocolDecl(ObjCProtocolDecl *PD) {
+    Events.push_back(VisitEvent::StartTraverseObjCProtocol);
+    bool Ret = RecursiveASTVisitor::TraverseObjCProtocolDecl(PD);
+    Events.push_back(VisitEvent::EndTraverseObjCProtocol);
+
+    return Ret;
+  }
+
+  bool TraverseObjCProtocolLoc(ObjCProtocolLoc ProtocolLoc) {
+    Events.push_back(VisitEvent::StartTraverseObjCProtocolLoc);
+    bool Ret = RecursiveASTVisitor::TraverseObjCProtocolLoc(ProtocolLoc);
+    Events.push_back(VisitEvent::EndTraverseObjCProtocolLoc);
+
+    return Ret;
+  }
+
   std::vector<VisitEvent> takeEvents() && { return std::move(Events); }
 
 private:
   std::vector<VisitEvent> Events;
 };
 
-std::vector<VisitEvent> collectEvents(llvm::StringRef Code) {
+std::vector<VisitEvent> collectEvents(llvm::StringRef Code,
+                                      const Twine &FileName = "input.cc") {
   CollectInterestingEvents Visitor;
   clang::tooling::runToolOnCode(
       std::make_unique<ProcessASTAction>(
           [&](clang::ASTContext &Ctx) { Visitor.TraverseAST(Ctx); }),
-      Code);
+      Code, FileName);
   return std::move(Visitor).takeEvents();
 }
 } // namespace
@@ -139,3 +170,28 @@
                           VisitEvent::EndTraverseTypedefType,
                           VisitEvent::EndTraverseEnum));
 }
+
+TEST(RecursiveASTVisitorTest, InterfaceDeclWithProtocols) {
+  // Check that the interface and its protocols are visited.
+  llvm::StringRef Code = R"cpp(
+  @protocol Foo
+  @end
+  @protocol Bar
+  @end
+
+  @interface SomeObject <Foo, Bar>
+  @end
+  )cpp";
+
+  EXPECT_THAT(collectEvents(Code, "input.m"),
+              ElementsAre(VisitEvent::StartTraverseObjCProtocol,
+                          VisitEvent::EndTraverseObjCProtocol,
+                          VisitEvent::StartTraverseObjCProtocol,
+                          VisitEvent::EndTraverseObjCProtocol,
+                          VisitEvent::StartTraverseObjCInterface,
+                          VisitEvent::StartTraverseObjCProtocolLoc,
+                          VisitEvent::EndTraverseObjCProtocolLoc,
+                          VisitEvent::StartTraverseObjCProtocolLoc,
+                          VisitEvent::EndTraverseObjCProtocolLoc,
+                          VisitEvent::EndTraverseObjCInterface));
+}
Index: clang/lib/AST/ParentMapContext.cpp
===================================================================
--- clang/lib/AST/ParentMapContext.cpp
+++ clang/lib/AST/ParentMapContext.cpp
@@ -330,6 +330,9 @@
 DynTypedNode createDynTypedNode(const NestedNameSpecifierLoc &Node) {
   return DynTypedNode::create(Node);
 }
+template <> DynTypedNode createDynTypedNode(const ObjCProtocolLoc &Node) {
+  return DynTypedNode::create(Node);
+}
 /// @}
 
 /// A \c RecursiveASTVisitor that builds a map from nodes to their
@@ -433,6 +436,12 @@
         AttrNode, AttrNode, [&] { return VisitorBase::TraverseAttr(AttrNode); },
         &Map.PointerParents);
   }
+  bool TraverseObjCProtocolLoc(ObjCProtocolLoc ProtocolLocNode) {
+    return TraverseNode(
+        ProtocolLocNode, DynTypedNode::create(ProtocolLocNode),
+        [&] { return VisitorBase::TraverseObjCProtocolLoc(ProtocolLocNode); },
+        &Map.OtherParents);
+  }
 
   // Using generic TraverseNode for Stmt would prevent data-recursion.
   bool dataTraverseStmtPre(Stmt *StmtNode) {
Index: clang/lib/AST/ASTTypeTraits.cpp
===================================================================
--- clang/lib/AST/ASTTypeTraits.cpp
+++ clang/lib/AST/ASTTypeTraits.cpp
@@ -16,6 +16,7 @@
 #include "clang/AST/ASTContext.h"
 #include "clang/AST/Attr.h"
 #include "clang/AST/DeclCXX.h"
+#include "clang/AST/DeclObjC.h"
 #include "clang/AST/NestedNameSpecifier.h"
 #include "clang/AST/OpenMPClause.h"
 #include "clang/AST/TypeLoc.h"
@@ -52,6 +53,7 @@
     {NKI_None, "Attr"},
 #define ATTR(A) {NKI_Attr, #A "Attr"},
 #include "clang/Basic/AttrList.inc"
+    {NKI_None, "ObjCProtocolLoc"},
 };
 
 bool ASTNodeKind::isBaseOf(ASTNodeKind Other, unsigned *Distance) const {
@@ -193,6 +195,8 @@
     QualType(T, 0).print(OS, PP);
   else if (const Attr *A = get<Attr>())
     A->printPretty(OS, PP);
+  else if (const ObjCProtocolLoc *P = get<ObjCProtocolLoc>())
+    P->getProtocol()->print(OS, PP);
   else
     OS << "Unable to print values of type " << NodeKind.asStringRef() << "\n";
 }
@@ -228,5 +232,7 @@
     return CBS->getSourceRange();
   if (const auto *A = get<Attr>())
     return A->getRange();
+  if (const ObjCProtocolLoc *P = get<ObjCProtocolLoc>())
+    return P->getSourceRange();
   return SourceRange();
 }
Index: clang/include/clang/AST/TypeLoc.h
===================================================================
--- clang/include/clang/AST/TypeLoc.h
+++ clang/include/clang/AST/TypeLoc.h
@@ -2607,6 +2607,25 @@
     : public InheritingConcreteTypeLoc<TypeSpecTypeLoc, DependentBitIntTypeLoc,
                                        DependentBitIntType> {};
 
+class ObjCProtocolLoc {
+  const ObjCProtocolDecl *Protocol = nullptr;
+  SourceLocation Loc = SourceLocation();
+
+public:
+  ObjCProtocolLoc(const ObjCProtocolDecl *protocol, SourceLocation loc)
+      : Protocol(protocol), Loc(loc) {}
+  const ObjCProtocolDecl *getProtocol() const { return Protocol; }
+  SourceLocation getLocation() const { return Loc; }
+
+  /// Get the full source range.
+  SourceRange getSourceRange() const LLVM_READONLY {
+    return SourceRange(Loc, Loc);
+  }
+
+  /// Evaluates to true when this protocol loc is valid/non-empty.
+  explicit operator bool() const { return Protocol; }
+};
+
 } // namespace clang
 
 #endif // LLVM_CLANG_AST_TYPELOC_H
Index: clang/include/clang/AST/RecursiveASTVisitor.h
===================================================================
--- clang/include/clang/AST/RecursiveASTVisitor.h
+++ clang/include/clang/AST/RecursiveASTVisitor.h
@@ -324,6 +324,12 @@
   /// \returns false if the visitation was terminated early, true otherwise.
   bool TraverseConceptReference(const ConceptReference &C);
 
+  /// Recursively visit an Objective-C protocol reference with location
+  /// information.
+  ///
+  /// \returns false if the visitation was terminated early, true otherwise.
+  bool TraverseObjCProtocolLoc(ObjCProtocolLoc ProtocolLoc);
+
   // ---- Methods on Attrs ----
 
   // Visit an attribute.
@@ -1340,7 +1346,12 @@
 DEF_TRAVERSE_TYPELOC(PackExpansionType,
                      { TRY_TO(TraverseTypeLoc(TL.getPatternLoc())); })
 
-DEF_TRAVERSE_TYPELOC(ObjCTypeParamType, {})
+DEF_TRAVERSE_TYPELOC(ObjCTypeParamType, {
+  for (unsigned i = 0, n = TL.getNumProtocols(); i != n; ++i) {
+    ObjCProtocolLoc ProtocolLoc(TL.getProtocol(i), TL.getProtocolLoc(i));
+    TRY_TO(TraverseObjCProtocolLoc(ProtocolLoc));
+  }
+})
 
 DEF_TRAVERSE_TYPELOC(ObjCInterfaceType, {})
 
@@ -1351,6 +1362,10 @@
     TRY_TO(TraverseTypeLoc(TL.getBaseLoc()));
   for (unsigned i = 0, n = TL.getNumTypeArgs(); i != n; ++i)
     TRY_TO(TraverseTypeLoc(TL.getTypeArgTInfo(i)->getTypeLoc()));
+  for (unsigned i = 0, n = TL.getNumProtocols(); i != n; ++i) {
+    ObjCProtocolLoc ProtocolLoc(TL.getProtocol(i), TL.getProtocolLoc(i));
+    TRY_TO(TraverseObjCProtocolLoc(ProtocolLoc));
+  }
 })
 
 DEF_TRAVERSE_TYPELOC(ObjCObjectPointerType,
@@ -1541,12 +1556,16 @@
 DEF_TRAVERSE_DECL(ObjCCompatibleAliasDecl, {// FIXME: implement
                                            })
 
-DEF_TRAVERSE_DECL(ObjCCategoryDecl, {// FIXME: implement
+DEF_TRAVERSE_DECL(ObjCCategoryDecl, {
   if (ObjCTypeParamList *typeParamList = D->getTypeParamList()) {
     for (auto typeParam : *typeParamList) {
       TRY_TO(TraverseObjCTypeParamDecl(typeParam));
     }
   }
+  for (auto It : llvm::zip(D->protocols(), D->protocol_locs())) {
+    ObjCProtocolLoc ProtocolLoc(std::get<0>(It), std::get<1>(It));
+    TRY_TO(TraverseObjCProtocolLoc(ProtocolLoc));
+  }
 })
 
 DEF_TRAVERSE_DECL(ObjCCategoryImplDecl, {// FIXME: implement
@@ -1555,7 +1574,7 @@
 DEF_TRAVERSE_DECL(ObjCImplementationDecl, {// FIXME: implement
                                           })
 
-DEF_TRAVERSE_DECL(ObjCInterfaceDecl, {// FIXME: implement
+DEF_TRAVERSE_DECL(ObjCInterfaceDecl, {
   if (ObjCTypeParamList *typeParamList = D->getTypeParamListAsWritten()) {
     for (auto typeParam : *typeParamList) {
       TRY_TO(TraverseObjCTypeParamDecl(typeParam));
@@ -1565,10 +1584,22 @@
   if (TypeSourceInfo *superTInfo = D->getSuperClassTInfo()) {
     TRY_TO(TraverseTypeLoc(superTInfo->getTypeLoc()));
   }
+  if (D->isThisDeclarationADefinition()) {
+    for (auto It : llvm::zip(D->protocols(), D->protocol_locs())) {
+      ObjCProtocolLoc ProtocolLoc(std::get<0>(It), std::get<1>(It));
+      TRY_TO(TraverseObjCProtocolLoc(ProtocolLoc));
+    }
+  }
 })
 
-DEF_TRAVERSE_DECL(ObjCProtocolDecl, {// FIXME: implement
-                                    })
+DEF_TRAVERSE_DECL(ObjCProtocolDecl, {
+  if (D->isThisDeclarationADefinition()) {
+    for (auto It : llvm::zip(D->protocols(), D->protocol_locs())) {
+      ObjCProtocolLoc ProtocolLoc(std::get<0>(It), std::get<1>(It));
+      TRY_TO(TraverseObjCProtocolLoc(ProtocolLoc));
+    }
+  }
+})
 
 DEF_TRAVERSE_DECL(ObjCMethodDecl, {
   if (D->getReturnTypeSourceInfo()) {
@@ -2409,6 +2440,12 @@
   return true;
 }
 
+template <typename Derived>
+bool RecursiveASTVisitor<Derived>::TraverseObjCProtocolLoc(
+    ObjCProtocolLoc ProtocolLoc) {
+  return true;
+}
+
 // If shouldVisitImplicitCode() returns false, this method traverses only the
 // syntactic form of InitListExpr.
 // If shouldVisitImplicitCode() return true, this method is called once for
Index: clang/include/clang/AST/ASTTypeTraits.h
===================================================================
--- clang/include/clang/AST/ASTTypeTraits.h
+++ clang/include/clang/AST/ASTTypeTraits.h
@@ -160,6 +160,7 @@
     NKI_Attr,
 #define ATTR(A) NKI_##A##Attr,
 #include "clang/Basic/AttrList.inc"
+    NKI_ObjCProtocolLoc,
     NKI_NumberOfKinds
   };
 
@@ -213,6 +214,7 @@
 KIND_TO_KIND_ID(Type)
 KIND_TO_KIND_ID(OMPClause)
 KIND_TO_KIND_ID(Attr)
+KIND_TO_KIND_ID(ObjCProtocolLoc)
 KIND_TO_KIND_ID(CXXBaseSpecifier)
 #define DECL(DERIVED, BASE) KIND_TO_KIND_ID(DERIVED##Decl)
 #include "clang/AST/DeclNodes.inc"
@@ -499,7 +501,7 @@
   /// have storage or unique pointers and thus need to be stored by value.
   llvm::AlignedCharArrayUnion<const void *, TemplateArgument,
                               TemplateArgumentLoc, NestedNameSpecifierLoc,
-                              QualType, TypeLoc>
+                              QualType, TypeLoc, ObjCProtocolLoc>
       Storage;
 };
 
@@ -570,6 +572,10 @@
 struct DynTypedNode::BaseConverter<CXXBaseSpecifier, void>
     : public PtrConverter<CXXBaseSpecifier> {};
 
+template <>
+struct DynTypedNode::BaseConverter<ObjCProtocolLoc, void>
+    : public ValueConverter<ObjCProtocolLoc> {};
+
 // The only operation we allow on unsupported types is \c get.
 // This allows to conveniently use \c DynTypedNode when having an arbitrary
 // AST node that is not supported, but prevents misuse - a user cannot create
Index: clang/include/clang/AST/ASTFwd.h
===================================================================
--- clang/include/clang/AST/ASTFwd.h
+++ clang/include/clang/AST/ASTFwd.h
@@ -33,6 +33,7 @@
 class Attr;
 #define ATTR(A) class A##Attr;
 #include "clang/Basic/AttrList.inc"
+class ObjCProtocolLoc;
 
 } // end namespace clang
 
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to