================
@@ -9003,6 +9037,218 @@ static void 
HandleHLSLParamModifierAttr(TypeProcessingState &State,
   }
 }
 
+/// Map the pair of counted_by-family flags onto the matching
+/// CountAttributedType::DynamicCountPointerKind.
+/// \param CountInBytes true for sized_by / sized_by_or_null.
+/// \param OrNull true for the *_or_null variants.
+static CountAttributedType::DynamicCountPointerKind
+getCountAttrKind(bool CountInBytes, bool OrNull) {
+  if (OrNull)
+    return CountInBytes ? CountAttributedType::SizedByOrNull
+                        : CountAttributedType::CountedByOrNull;
+  return CountInBytes ? CountAttributedType::SizedBy
+                      : CountAttributedType::CountedBy;
+}
+
+/// Reason a counted_by-family attribute's pointee (or array element) type was
+/// rejected; VALID means no problem was found.
+/// validateCountedByAttrType streams the enumerator's integer value into the
+/// pointee-unknown-size diagnostics, so the numbering is significant
+/// (presumably it indexes a %select in the diagnostic text -- keep in sync).
+enum class CountedByInvalidPointeeTypeKind {
+  INCOMPLETE = 0,
+  SIZELESS = 1,
+  FUNCTION = 2,
+  FLEXIBLE_ARRAY_MEMBER = 3,
+  VALID = 4,
+};
+
+/// Calculate the pointer nesting level used for counted_by attribute
+/// validation.
+///
+/// Declarator chunks are ordered from the identifier outward, with index 0
+/// being the outermost type object. This counts the Pointer, Array and
+/// Function chunks that come before \p ChunkIndex, i.e. the enclosing
+/// pointer/array/function levels of the chunk currently being processed.
+///
+/// \param State The type processing state whose declarator is inspected.
+/// \param ChunkIndex The index of the current declarator chunk. Chunks
+///        [0, ChunkIndex) are examined; it may equal the number of type
+///        objects.
+/// \return The number of Pointer/Array/Function chunks before ChunkIndex
+///         (0 when ChunkIndex is 0).
+static unsigned getPointerNestLevel(TypeProcessingState &State,
+                                    unsigned ChunkIndex) {
+  const auto &D = State.getDeclarator();
+  assert(ChunkIndex <= D.getNumTypeObjects() && "chunk index out of range");
+
+  // A ChunkIndex of 0 needs no special casing: the loop body never runs and
+  // the nest level is 0.
+  unsigned PointerNestLevel = 0;
+  for (unsigned I = 0; I < ChunkIndex; ++I) {
+    // Only the chunk's Kind is needed; don't copy the whole chunk into an
+    // `auto` local just to read it.
+    switch (D.getTypeObject(I).Kind) {
+    case DeclaratorChunk::Function:
+    case DeclaratorChunk::Array:
+    case DeclaratorChunk::Pointer:
+      ++PointerNestLevel;
+      break;
+    default:
+      break;
+    }
+  }
+  return PointerNestLevel;
+}
+
+/// Check that a counted_by-family attribute may be applied to \p Ty.
+///
+/// \param S Sema used to emit diagnostics.
+/// \param Ty The type the attribute is attached to.
+/// \param AttrKind One of AT_CountedBy, AT_CountedByOrNull, AT_SizedBy or
+///        AT_SizedByOrNull; any other kind is unreachable.
+/// \param AttrLoc Location used for every diagnostic emitted here.
+/// \param pointerNestLevel Number of enclosing pointer/array/function levels;
+///        any non-zero value is rejected.
+/// \param[out] CountInBytes Set to true for the sized_by variants.
+/// \param[out] OrNull Set to true for the *_or_null variants.
+/// \return true if the attribute is acceptable; false after a diagnostic has
+///         been emitted for an unsupported type (for an array whose element
+///         is a struct with a flexible array member, that diagnostic may be
+///         a warning rather than an error).
+static bool validateCountedByAttrType(Sema &S, QualType Ty,
+                                      ParsedAttr::Kind AttrKind,
+                                      SourceLocation AttrLoc,
+                                      unsigned pointerNestLevel,
+                                      bool &CountInBytes, bool &OrNull) {
+  // Decode the attribute spelling into its two independent flags.
+  switch (AttrKind) {
+  case ParsedAttr::AT_CountedBy:
+  case ParsedAttr::AT_CountedByOrNull:
+    CountInBytes = false;
+    break;
+  case ParsedAttr::AT_SizedBy:
+  case ParsedAttr::AT_SizedByOrNull:
+    CountInBytes = true;
+    break;
+  default:
+    llvm_unreachable("unexpected counted_by family attribute");
+  }
+  OrNull = AttrKind == ParsedAttr::AT_CountedByOrNull ||
+           AttrKind == ParsedAttr::AT_SizedByOrNull;
+  const unsigned Kind = getCountAttrKind(CountInBytes, OrNull);
+  const bool IsArray = Ty->isArrayType();
+
+  // Arrays only accept plain counted_by; pointers accept every variant.
+  if (IsArray && (CountInBytes || OrNull)) {
+    S.Diag(AttrLoc, diag::err_count_attr_not_on_ptr_or_flexible_array_member)
+        << Kind << /* suggest counted_by */ 1;
+    return false;
+  }
+  if (!IsArray && !Ty->isPointerType()) {
+    S.Diag(AttrLoc, diag::err_count_attr_not_on_ptr_or_flexible_array_member)
+        << Kind << /* do not suggest counted_by */ 0;
+    return false;
+  }
+
+  // Ty is now either an array or a pointer. Inspect the type being counted:
+  // the element type of an array, the pointee of a pointer. SelectPtrOrArr
+  // picks between the two wordings in the diagnostic.
+  const int SelectPtrOrArr = IsArray ? 1 : 0;
+  QualType PointeeTy =
+      IsArray ? S.getASTContext().getAsArrayType(Ty)->getElementType()
+              : Ty->getPointeeType();
+
+  auto InvalidTypeKind = CountedByInvalidPointeeTypeKind::VALID;
+  bool ShouldWarn = false;
+  if (!CountInBytes && PointeeTy->isAlwaysIncompleteType()) {
+    if (PointeeTy->isVoidType()) {
+      // void pointee: emit the ext_gnu_* diagnostic plus a note suggesting
+      // sized_by, but keep treating the pointee as valid.
+      S.Diag(AttrLoc, diag::ext_gnu_counted_by_void_ptr) << Kind;
+      S.Diag(AttrLoc, diag::note_gnu_counted_by_void_ptr_use_sized_by) << Kind;
+    } else {
+      InvalidTypeKind = CountedByInvalidPointeeTypeKind::INCOMPLETE;
+    }
+  } else if (PointeeTy->isSizelessType()) {
+    InvalidTypeKind = CountedByInvalidPointeeTypeKind::SIZELESS;
+  } else if (PointeeTy->isFunctionType()) {
+    InvalidTypeKind = CountedByInvalidPointeeTypeKind::FUNCTION;
+  } else if (PointeeTy->isStructureTypeWithFlexibleArrayMember()) {
+    InvalidTypeKind = CountedByInvalidPointeeTypeKind::FLEXIBLE_ARRAY_MEMBER;
+    // For an array of such structs, when LangOpts.BoundsSafety is off, the
+    // problem is reported as a warning instead of an error (validation
+    // still fails).
+    ShouldWarn = IsArray && !S.getLangOpts().BoundsSafety;
+  }
+
+  if (InvalidTypeKind != CountedByInvalidPointeeTypeKind::VALID) {
+    const unsigned DiagID =
+        ShouldWarn ? diag::warn_counted_by_attr_elt_type_unknown_size
+                   : diag::err_counted_by_attr_pointee_unknown_size;
+    S.Diag(AttrLoc, DiagID) << SelectPtrOrArr << PointeeTy
+                            << static_cast<int>(InvalidTypeKind)
+                            << (ShouldWarn ? 1 : 0) << Kind;
+    return false;
+  }
+
+  // Any enclosing pointer/array/function level is rejected.
+  if (pointerNestLevel != 0) {
+    S.Diag(AttrLoc, diag::err_counted_by_on_nested_pointer) << Kind;
+    return false;
+  }
+  return true;
+}
+
+static void HandleCountedByAttrOnType(TypeProcessingState &State,
+                                      QualType &CurType, ParsedAttr &Attr) {
+  Sema &S = State.getSema();
+
+  // This attribute is only supported in C.
+  // FIXME: we should implement checkCommonAttributeFeatures() in SemaAttr.cpp
+  // such that it handles type attributes, and then call that from
+  // processTypeAttrs() instead of one-off checks like this.
+  if (!Attr.diagnoseLangOpts(S)) {
+    Attr.setInvalid();
+    return;
+  }
+
+  auto *CountExpr = Attr.getArgAsExpr(0);
+  if (!CountExpr)
+    return;
+
+  // This is a mechanism to prevent nested count pointer types in the contexts
+  // where late parsing isn't allowed: currently that is any context other than
+  // struct fields. In the context where late parsing is allowed, the level
+  // check will be done once the whole context is constructed.
+  unsigned chunkIndex = State.getCurrentChunkIndex();
+  unsigned pointerNestLevel = 0;
+
+  // Only calculate pointer nest level if we're processing a declarator chunk.
+  // For DeclSpec attributes, the declarator hasn't been constructed yet.
+  if (chunkIndex > 0) {
----------------
zmodem wrote:

`getPointerNestLevel` also has the `chunkIndex > 0` check. I guess at least one 
of them is redundant.

https://github.com/llvm/llvm-project/pull/179612
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to