danix800 created this revision.
danix800 added reviewers: balazske, steakhal, aaron.ballman, shafik, martong.
Herald added a reviewer: a.sidorin.
Herald added a project: All.
danix800 requested review of this revision.
Herald added a project: clang.
Herald added a subscriber: cfe-commits.

Repeated friends are deduplicated when imported, but StructuralEquivalence
checks friends by exact, position-by-position matching.
If `ToContext` is empty (it does not yet contain the class being imported), the
imported friends are deduplicated, so any
further import of the same class is rejected by the structural equivalence
check, e.g.:

  struct foo { friend class X; friend class X; }; // FromContext

only one friend is imported, yielding the equivalent of:

  struct foo { friend class X; }; // ToContext

but when it is imported again, `struct foo` in `FromContext` is reported as not
equivalent to `struct foo` in `ToContext`,
and the import is rejected.

This patch improves the structural equivalence check by applying the same
kind of deduplication that the importer does.
Thus, from StructuralEquivalence's point of view, the two `RecordDecl`s above
are equivalent.


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D157114

Files:
  clang/lib/AST/ASTImporter.cpp
  clang/lib/AST/ASTStructuralEquivalence.cpp
  clang/unittests/AST/ASTImporterTest.cpp
  clang/unittests/AST/StructuralEquivalenceTest.cpp

Index: clang/unittests/AST/StructuralEquivalenceTest.cpp
===================================================================
--- clang/unittests/AST/StructuralEquivalenceTest.cpp
+++ clang/unittests/AST/StructuralEquivalenceTest.cpp
@@ -833,7 +833,7 @@
   auto t = makeNamedDecls("struct foo { friend class X; };",
                           "struct foo { friend class X; friend class X; };",
                           Lang_CXX11);
-  EXPECT_FALSE(testStructuralMatch(t));
+  EXPECT_TRUE(testStructuralMatch(t));
 }
 
 TEST_F(StructuralEquivalenceRecordTest, SameFriendsDifferentOrder) {
Index: clang/unittests/AST/ASTImporterTest.cpp
===================================================================
--- clang/unittests/AST/ASTImporterTest.cpp
+++ clang/unittests/AST/ASTImporterTest.cpp
@@ -10,6 +10,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "clang/AST/ASTStructuralEquivalence.h"
 #include "clang/AST/RecordLayout.h"
 #include "clang/ASTMatchers/ASTMatchers.h"
 #include "llvm/ADT/StringMap.h"
@@ -4351,6 +4352,44 @@
   EXPECT_EQ(ToFriend2, ToImportedFriend2);
 }
 
+TEST_P(ASTImporterOptionSpecificTestBase, ImportRepeatedFriendDeclIntoEmptyDC) {
+  Decl *From, *To;
+  std::tie(From, To) = getImportedDecl(R"(
+      template <class T>
+      class A {
+      public:
+        template <class U> friend A<U> &f();
+        template <class U> friend A<U> &f();
+      };
+  )",
+                                       Lang_CXX17, "", Lang_CXX17, "A");
+
+  auto *FromFriend1 = FirstDeclMatcher<FriendDecl>().match(From, friendDecl());
+  auto *FromFriend2 = LastDeclMatcher<FriendDecl>().match(From, friendDecl());
+  auto *ToFriend1 = FirstDeclMatcher<FriendDecl>().match(To, friendDecl());
+  auto *ToFriend2 = LastDeclMatcher<FriendDecl>().match(To, friendDecl());
+
+  // The From context keeps both (repeated) FriendDecls.
+  EXPECT_NE(FromFriend1, FromFriend2);
+  // Only one of them is imported into the empty To context.
+  EXPECT_EQ(ToFriend1, ToFriend2);
+
+  // 'A' must stay structurally equivalent to its import, in both directions.
+  llvm::DenseSet<std::pair<Decl *, Decl *>> NonEquivalentDecls01;
+  llvm::DenseSet<std::pair<Decl *, Decl *>> NonEquivalentDecls10;
+  StructuralEquivalenceContext Ctx01(
+      From->getASTContext(), To->getASTContext(), NonEquivalentDecls01,
+      StructuralEquivalenceKind::Default, false, false);
+  StructuralEquivalenceContext Ctx10(
+      To->getASTContext(), From->getASTContext(), NonEquivalentDecls10,
+      StructuralEquivalenceKind::Default, false, false);
+
+  bool Eq01 = Ctx01.IsEquivalent(From, To);
+  bool Eq10 = Ctx10.IsEquivalent(To, From);
+  EXPECT_EQ(Eq01, Eq10);
+  EXPECT_TRUE(Eq01);
+}
+
 TEST_P(ASTImporterOptionSpecificTestBase, FriendFunInClassTemplate) {
   auto *Code = R"(
   template <class T>
Index: clang/lib/AST/ASTStructuralEquivalence.cpp
===================================================================
--- clang/lib/AST/ASTStructuralEquivalence.cpp
+++ clang/lib/AST/ASTStructuralEquivalence.cpp
@@ -1464,6 +1464,160 @@
   return IsStructurallyEquivalent(GetName(D1), GetName(D2));
 }
 
+/// Checks that two C++ records have equivalent bases, in order and virtuality.
+static bool
+IsCXXRecordBaseStructurallyEquivalent(StructuralEquivalenceContext &Context,
+                                      RecordDecl *D1, RecordDecl *D2) {
+  auto *D1CXX = cast<CXXRecordDecl>(D1);
+  auto *D2CXX = cast<CXXRecordDecl>(D2);
+
+  if (D1CXX->getNumBases() != D2CXX->getNumBases()) {
+    if (Context.Complain) {
+      Context.Diag2(D2->getLocation(), Context.getApplicableDiagnostic(
+                                           diag::err_odr_tag_type_inconsistent))
+          << Context.ToCtx.getTypeDeclType(D2);
+      Context.Diag2(D2->getLocation(), diag::note_odr_number_of_bases)
+          << D2CXX->getNumBases();
+      Context.Diag1(D1->getLocation(), diag::note_odr_number_of_bases)
+          << D1CXX->getNumBases();
+    }
+    return false;
+  }
+
+  for (auto Base1 = D1CXX->bases_begin(), BaseEnd1 = D1CXX->bases_end(),
+            Base2 = D2CXX->bases_begin();
+       Base1 != BaseEnd1; ++Base1, ++Base2) {
+    if (!IsStructurallyEquivalent(Context, Base1->getType(),
+                                  Base2->getType())) {
+      if (Context.Complain) {
+        Context.Diag2(D2->getLocation(),
+                      Context.getApplicableDiagnostic(
+                          diag::err_odr_tag_type_inconsistent))
+            << Context.ToCtx.getTypeDeclType(D2);
+        Context.Diag2(Base2->getBeginLoc(), diag::note_odr_base)
+            << Base2->getType() << Base2->getSourceRange();
+        Context.Diag1(Base1->getBeginLoc(), diag::note_odr_base)
+            << Base1->getType() << Base1->getSourceRange();
+      }
+      return false;
+    }
+
+    // Virtual vs. non-virtual mismatch: both notes take the isVirtual() flag.
+    if (Base1->isVirtual() != Base2->isVirtual()) {
+      if (Context.Complain) {
+        Context.Diag2(D2->getLocation(),
+                      Context.getApplicableDiagnostic(
+                          diag::err_odr_tag_type_inconsistent))
+            << Context.ToCtx.getTypeDeclType(D2);
+        Context.Diag2(Base2->getBeginLoc(), diag::note_odr_virtual_base)
+            << Base2->isVirtual() << Base2->getSourceRange();
+        Context.Diag1(Base1->getBeginLoc(), diag::note_odr_virtual_base)
+            << Base1->isVirtual() << Base1->getSourceRange();
+      }
+      return false;
+    }
+  }
+
+  return true;
+}
+
+using NonEquivalentDeclSet = llvm::DenseSet<std::pair<Decl *, Decl *>>;
+
+/// Returns true if \p F1 and \p F2 befriend equivalent declarations or types.
+static bool IsEquivalentFriend(FriendDecl *F1, FriendDecl *F2,
+                               NonEquivalentDeclSet &NonEquivalentDecls) {
+  StructuralEquivalenceContext Ctx(
+      F1->getASTContext(), F2->getASTContext(), NonEquivalentDecls,
+      StructuralEquivalenceKind::Minimal, false, false);
+  NamedDecl *ND1 = F1->getFriendDecl(), *ND2 = F2->getFriendDecl();
+  if (ND1 && ND2)
+    return Ctx.IsEquivalent(ND1, ND2);
+  TypeSourceInfo *TSI1 = F1->getFriendType(), *TSI2 = F2->getFriendType();
+  // A befriended declaration never matches a befriended type.
+  return TSI1 && TSI2 && Ctx.IsEquivalent(TSI1->getType(), TSI2->getType());
+}
+
+/// Returns true if \p F is equivalent to at least one friend in \p Friends.
+static bool
+IsEquivalentToAnyExistingFriends(FriendDecl *F, ArrayRef<FriendDecl *> Friends,
+                                 NonEquivalentDeclSet &NonEquivalentDecls) {
+  for (FriendDecl *Kept : Friends)
+    if (IsEquivalentFriend(F, Kept, NonEquivalentDecls))
+      return true;
+  return false;
+}
+
+/// Returns the friends of \p RD in declaration order, keeping only the first
+/// friend of each group of structurally equivalent ones. This mirrors the
+/// deduplication applied by the ASTImporter, so a record and its import can
+/// compare equal even when the original repeats a friend declaration.
+static SmallVector<FriendDecl *, 2> getDeduplicatedFriends(CXXRecordDecl *RD) {
+  NonEquivalentDeclSet NonEquivalentDecls;
+  SmallVector<FriendDecl *, 2> EquivalentFriends;
+
+  // Each friend is compared with every friend kept so far, so this is
+  // quadratic in the number of friends. The first friend is always kept,
+  // since there is nothing yet to compare it with.
+  for (FriendDecl *Friend : RD->friends()) {
+    if (!IsEquivalentToAnyExistingFriends(Friend, EquivalentFriends,
+                                          NonEquivalentDecls))
+      EquivalentFriends.push_back(Friend);
+  }
+
+  return EquivalentFriends;
+}
+
+/// Compares the friends of two C++ records pairwise, in declaration order,
+/// after repeated friends have been dropped from each record.
+static bool
+IsFriendInCXXRecordStructurallyEquivalent(StructuralEquivalenceContext &Context,
+                                          RecordDecl *D1, RecordDecl *D2) {
+  auto Friends1 = getDeduplicatedFriends(cast<CXXRecordDecl>(D1));
+  auto Friends2 = getDeduplicatedFriends(cast<CXXRecordDecl>(D2));
+  const size_t N1 = Friends1.size(), N2 = Friends2.size();
+
+  for (size_t I = 0; I != N1; ++I) {
+    FriendDecl *F1 = Friends1[I];
+    if (I == N2) {
+      // D2 has run out of friends while D1 still has some left.
+      if (Context.Complain) {
+        Context.Diag2(D2->getLocation(),
+                      Context.getApplicableDiagnostic(
+                          diag::err_odr_tag_type_inconsistent))
+            << Context.ToCtx.getTypeDeclType(D2);
+        Context.Diag1(F1->getFriendLoc(), diag::note_odr_friend);
+        Context.Diag2(D2->getLocation(), diag::note_odr_missing_friend);
+      }
+      return false;
+    }
+    FriendDecl *F2 = Friends2[I];
+    if (IsStructurallyEquivalent(Context, F1, F2))
+      continue;
+    if (Context.Complain) {
+      Context.Diag2(D2->getLocation(), Context.getApplicableDiagnostic(
+                                           diag::err_odr_tag_type_inconsistent))
+          << Context.ToCtx.getTypeDeclType(D2);
+      Context.Diag1(F1->getFriendLoc(), diag::note_odr_friend);
+      Context.Diag2(F2->getFriendLoc(), diag::note_odr_friend);
+    }
+    return false;
+  }
+
+  if (N2 > N1) {
+    // D1 has run out of friends while D2 still has some left.
+    if (Context.Complain) {
+      Context.Diag2(D2->getLocation(), Context.getApplicableDiagnostic(
+                                           diag::err_odr_tag_type_inconsistent))
+          << Context.ToCtx.getTypeDeclType(D2);
+      Context.Diag2(Friends2[N1]->getFriendLoc(), diag::note_odr_friend);
+      Context.Diag1(D1->getLocation(), diag::note_odr_missing_friend);
+    }
+    return false;
+  }
+
+  return true;
+}
+
 /// Determine structural equivalence of two records.
 static bool IsStructurallyEquivalent(StructuralEquivalenceContext &Context,
                                      RecordDecl *D1, RecordDecl *D2) {
@@ -1562,98 +1716,11 @@
           return false;
       }
 
-      if (D1CXX->getNumBases() != D2CXX->getNumBases()) {
-        if (Context.Complain) {
-          Context.Diag2(D2->getLocation(),
-                        Context.getApplicableDiagnostic(
-                            diag::err_odr_tag_type_inconsistent))
-              << Context.ToCtx.getTypeDeclType(D2);
-          Context.Diag2(D2->getLocation(), diag::note_odr_number_of_bases)
-              << D2CXX->getNumBases();
-          Context.Diag1(D1->getLocation(), diag::note_odr_number_of_bases)
-              << D1CXX->getNumBases();
-        }
+      if (!IsCXXRecordBaseStructurallyEquivalent(Context, D1, D2))
         return false;
-      }
-
-      // Check the base classes.
-      for (CXXRecordDecl::base_class_iterator Base1 = D1CXX->bases_begin(),
-                                              BaseEnd1 = D1CXX->bases_end(),
-                                              Base2 = D2CXX->bases_begin();
-           Base1 != BaseEnd1; ++Base1, ++Base2) {
-        if (!IsStructurallyEquivalent(Context, Base1->getType(),
-                                      Base2->getType())) {
-          if (Context.Complain) {
-            Context.Diag2(D2->getLocation(),
-                          Context.getApplicableDiagnostic(
-                              diag::err_odr_tag_type_inconsistent))
-                << Context.ToCtx.getTypeDeclType(D2);
-            Context.Diag2(Base2->getBeginLoc(), diag::note_odr_base)
-                << Base2->getType() << Base2->getSourceRange();
-            Context.Diag1(Base1->getBeginLoc(), diag::note_odr_base)
-                << Base1->getType() << Base1->getSourceRange();
-          }
-          return false;
-        }
-
-        // Check virtual vs. non-virtual inheritance mismatch.
-        if (Base1->isVirtual() != Base2->isVirtual()) {
-          if (Context.Complain) {
-            Context.Diag2(D2->getLocation(),
-                          Context.getApplicableDiagnostic(
-                              diag::err_odr_tag_type_inconsistent))
-                << Context.ToCtx.getTypeDeclType(D2);
-            Context.Diag2(Base2->getBeginLoc(), diag::note_odr_virtual_base)
-                << Base2->isVirtual() << Base2->getSourceRange();
-            Context.Diag1(Base1->getBeginLoc(), diag::note_odr_base)
-                << Base1->isVirtual() << Base1->getSourceRange();
-          }
-          return false;
-        }
-      }
 
-      // Check the friends for consistency.
-      CXXRecordDecl::friend_iterator Friend2 = D2CXX->friend_begin(),
-                                     Friend2End = D2CXX->friend_end();
-      for (CXXRecordDecl::friend_iterator Friend1 = D1CXX->friend_begin(),
-                                          Friend1End = D1CXX->friend_end();
-           Friend1 != Friend1End; ++Friend1, ++Friend2) {
-        if (Friend2 == Friend2End) {
-          if (Context.Complain) {
-            Context.Diag2(D2->getLocation(),
-                          Context.getApplicableDiagnostic(
-                              diag::err_odr_tag_type_inconsistent))
-                << Context.ToCtx.getTypeDeclType(D2CXX);
-            Context.Diag1((*Friend1)->getFriendLoc(), diag::note_odr_friend);
-            Context.Diag2(D2->getLocation(), diag::note_odr_missing_friend);
-          }
-          return false;
-        }
-
-        if (!IsStructurallyEquivalent(Context, *Friend1, *Friend2)) {
-          if (Context.Complain) {
-            Context.Diag2(D2->getLocation(),
-                          Context.getApplicableDiagnostic(
-                              diag::err_odr_tag_type_inconsistent))
-                << Context.ToCtx.getTypeDeclType(D2CXX);
-            Context.Diag1((*Friend1)->getFriendLoc(), diag::note_odr_friend);
-            Context.Diag2((*Friend2)->getFriendLoc(), diag::note_odr_friend);
-          }
-          return false;
-        }
-      }
-
-      if (Friend2 != Friend2End) {
-        if (Context.Complain) {
-          Context.Diag2(D2->getLocation(),
-                        Context.getApplicableDiagnostic(
-                            diag::err_odr_tag_type_inconsistent))
-              << Context.ToCtx.getTypeDeclType(D2);
-          Context.Diag2((*Friend2)->getFriendLoc(), diag::note_odr_friend);
-          Context.Diag1(D1->getLocation(), diag::note_odr_missing_friend);
-        }
+      if (!IsFriendInCXXRecordStructurallyEquivalent(Context, D1, D2))
         return false;
-      }
     } else if (D1CXX->getNumBases() > 0) {
       if (Context.Complain) {
         Context.Diag2(D2->getLocation(),
@@ -2327,8 +2394,8 @@
     Decl *D1 = P.first;
     Decl *D2 = P.second;
 
-    bool Equivalent =
-        CheckCommonEquivalence(D1, D2) && CheckKindSpecificEquivalence(D1, D2);
+    bool Equivalent = (D1 == D2) || (CheckCommonEquivalence(D1, D2) &&
+                                     CheckKindSpecificEquivalence(D1, D2));
 
     if (!Equivalent) {
       // Note that these two declarations are not equivalent (and we already
Index: clang/lib/AST/ASTImporter.cpp
===================================================================
--- clang/lib/AST/ASTImporter.cpp
+++ clang/lib/AST/ASTImporter.cpp
@@ -6433,7 +6433,8 @@
 
   ToFunc->setAccess(D->getAccess());
   ToFunc->setLexicalDeclContext(LexicalDC);
-  LexicalDC->addDeclInternal(ToFunc);
+  if (D->getFriendObjectKind() == Decl::FOK_None)
+    LexicalDC->addDeclInternal(ToFunc);
 
   ASTImporterLookupTable *LT = Importer.SharedState->getLookupTable();
   if (LT && !OldParamDC.empty()) {
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to