================
@@ -2592,6 +2595,281 @@ static RValue 
EmitHipStdParUnsupportedBuiltin(CodeGenFunction *CGF,
   return RValue::get(CGF->Builder.CreateCall(UBF, Args));
 }
 
+namespace {
+
+// PaddingClearer is a utility class that clears padding bits in a
+// C++ type. It traverses the type recursively, collecting occupied
+// bit intervals, and then computes the padding intervals.
+// In the end, it clears the padding bits by writing zeros
+// to the padding intervals byte by byte. If a byte only partially
+// consists of padding bits, it writes zeros to only those bits. This is
+// the case for bit-fields.
+struct PaddingClearer {
+  // Binds the clearer to the function being emitted and caches the char
+  // width (in bits) that is used to convert char offsets to bit offsets.
+  PaddingClearer(CodeGenFunction &F)
+      : CGF(F), CharWidth(F.getContext().getCharWidth()) {}
+
+  // Zeroes every padding bit of an object of type Ty stored at Ptr.
+  //
+  // The type is walked with an explicit work list: each visited entry either
+  // records the bits it occupies or schedules its sub-objects. Whatever is
+  // not occupied once the list is drained is treated as padding and cleared.
+  void run(Value *Ptr, QualType Ty) {
+    OccuppiedIntervals.clear();
+    Queue.clear();
+    Queue.push_back(Data{0, Ty, true});
+
+    do {
+      // Copy the entry out before popping it; Visit may push more entries.
+      const Data Next = Queue.back();
+      Queue.pop_back();
+      Visit(Next);
+    } while (!Queue.empty());
+
+    MergeOccuppiedIntervals();
+
+    const uint64_t TotalBits = CGF.getContext().getTypeSize(Ty);
+    const auto Gaps = GetPaddingIntervals(TotalBits);
+    for (const auto &Gap : Gaps)
+      ClearPadding(Ptr, Gap);
+  }
+
+private:
+  // Half-open range of bit offsets, counted from the start of the object
+  // passed to run().
+  struct BitInterval {
+    // [First, Last)
+    uint64_t First;
+    uint64_t Last;
+  };
+
+  // One pending work item: a (sub)object type and where it starts.
+  struct Data {
+    // Bit offset of this (sub)object from the start of the object passed to
+    // run().
+    uint64_t StartBitOffset;
+    // Type of the (sub)object to visit.
+    QualType Ty;
+    // Whether VisitStruct should also enqueue virtual bases for this entry;
+    // entries created for base-class subobjects set this to false.
+    bool VisitVirtualBase;
+  };
+
+  // Dispatches one work item by type category. Constant arrays, C++ records,
+  // atomics and complex types are broken down further; anything else is
+  // recorded as occupying the full size of its in-memory LLVM type.
+  void Visit(Data const &D) {
+    // NOTE(review): dyn_cast on a QualType presumably does not look through
+    // typedef sugar -- confirm whether sugared array types need handling.
+    if (auto *ArrayTy = dyn_cast<ConstantArrayType>(D.Ty))
+      return VisitArray(ArrayTy, D.StartBitOffset);
+
+    if (auto *RD = D.Ty->getAsCXXRecordDecl())
+      return VisitStruct(RD, D.StartBitOffset, D.VisitVirtualBase);
+
+    if (D.Ty->isAtomicType()) {
+      // Re-queue the same position with the atomic wrapper stripped.
+      Queue.push_back(Data{D.StartBitOffset, D.Ty.getAtomicUnqualifiedType(),
+                           D.VisitVirtualBase});
+      return;
+    }
+
+    if (const auto *ComplexTy = D.Ty->getAs<ComplexType>())
+      return VisitComplex(ComplexTy, D.StartBitOffset);
+
+    // Leaf: the whole in-memory representation counts as occupied.
+    auto *MemTy = CGF.ConvertTypeForMem(D.Ty);
+    const auto &Layout = CGF.CGM.getModule().getDataLayout();
+    const auto LeafBits = Layout.getTypeSizeInBits(MemTy).getKnownMinValue();
+    OccuppiedIntervals.push_back(
+        BitInterval{D.StartBitOffset, D.StartBitOffset + LeafBits});
+  }
+
+  // Enqueues every element of a constant-size array for visiting.
+  //
+  // Element I starts at StartBitOffset + I * Stride bits, where Stride is the
+  // element size rounded up to the element alignment, converted to bits.
+  void VisitArray(const ConstantArrayType *AT, uint64_t StartBitOffset) {
+    // Everything below is identical for all elements, so it is computed once
+    // up front rather than once per element.
+    const QualType ElementQualType = AT->getElementType();
+    const auto ElementSize =
+        CGF.getContext().getTypeSizeInChars(ElementQualType);
+    const auto ElementAlign =
+        CGF.getContext().getTypeAlignInChars(ElementQualType);
+    const uint64_t StrideInBits =
+        ElementSize.alignTo(ElementAlign).getQuantity() * CharWidth;
+    const uint64_t NumElements = AT->getSize().getLimitedValue();
+
+    for (uint64_t ArrIndex = 0; ArrIndex < NumElements; ++ArrIndex) {
+      Queue.push_back(Data{StartBitOffset + ArrIndex * StrideInBits,
+                           ElementQualType, /*VisitVirtualBase*/ true});
+    }
+  }
+
+  // Processes a C++ record located at StartBitOffset.
+  //
+  // The record's own vtable pointer (if it has one) and its bit-fields are
+  // marked occupied right away; base classes and ordinary fields are
+  // enqueued to be visited later. Virtual bases are only enqueued when
+  // VisitVirtualBase is set, and enqueued bases never visit virtual bases
+  // themselves.
+  void VisitStruct(const CXXRecordDecl *R, uint64_t StartBitOffset,
+                   bool VisitVirtualBase) {
+    const ASTRecordLayout &Layout = CGF.getContext().getASTRecordLayout(R);
+
+    if (Layout.hasOwnVFPtr()) {
+      const uint64_t PtrBits =
+          CGF.CGM.getModule().getDataLayout().getPointerSizeInBits();
+      OccuppiedIntervals.push_back(
+          BitInterval{StartBitOffset, StartBitOffset + PtrBits});
+    }
+
+    // Enqueues one base subobject. OffsetOf is the ASTRecordLayout member
+    // function used to look up where that base lives inside R.
+    const auto EnqueueBase = [&](const CXXBaseSpecifier &Spec,
+                                 auto OffsetOf) {
+      auto *BaseDecl = Spec.getType()->getAsCXXRecordDecl();
+      if (BaseDecl) {
+        const auto OffsetInChars =
+            std::invoke(OffsetOf, Layout, BaseDecl).getQuantity();
+        Queue.push_back(Data{StartBitOffset + OffsetInChars * CharWidth,
+                             Spec.getType(), /*VisitVirtualBase*/ false});
+      }
+    };
+
+    // Direct, non-virtual bases.
+    for (const auto &Spec : R->bases()) {
+      if (Spec.isVirtual())
+        continue;
+      EnqueueBase(Spec, &ASTRecordLayout::getBaseClassOffset);
+    }
+
+    // Virtual bases, only when requested by the caller.
+    if (VisitVirtualBase) {
+      for (const auto &Spec : R->vbases())
+        EnqueueBase(Spec, &ASTRecordLayout::getVBaseClassOffset);
+    }
+
+    for (auto *FD : R->fields()) {
+      const uint64_t FieldStart =
+          StartBitOffset + Layout.getFieldOffset(FD->getFieldIndex());
+      if (!FD->isBitField()) {
+        Queue.push_back(
+            Data{FieldStart, FD->getType(), /*VisitVirtualBase*/ true});
+        continue;
+      }
+      // A bit-field occupies exactly its declared width.
+      OccuppiedIntervals.push_back(
+          BitInterval{FieldStart, FieldStart + FD->getBitWidthValue()});
+    }
+  }
+
+  // Enqueues the two components of a complex type: the first at
+  // StartBitOffset, the second one aligned element size further on.
+  void VisitComplex(const ComplexType *CT, uint64_t StartBitOffset) {
+    const QualType PartTy = CT->getElementType();
+    auto &Ctx = CGF.getContext();
+
+    const auto PartSize = Ctx.getTypeSizeInChars(PartTy);
+    const auto PartAlign = Ctx.getTypeAlignInChars(PartTy);
+    const auto SecondPartChars = PartSize.alignTo(PartAlign).getQuantity();
+
+    Queue.push_back(Data{StartBitOffset, PartTy, /*VisitVirtualBase*/ true});
+    Queue.push_back(Data{StartBitOffset + SecondPartChars * CharWidth, PartTy,
+                         /*VisitVirtualBase*/ true});
+  }
+
+  // Sorts the collected occupied intervals and coalesces every run of
+  // overlapping or touching intervals into a single interval, so that the
+  // result is ordered and pairwise disjoint.
+  void MergeOccuppiedIntervals() {
+    const auto ByStartThenEnd = [](const BitInterval &A,
+                                   const BitInterval &B) {
+      return std::tie(A.First, A.Last) < std::tie(B.First, B.Last);
+    };
+    std::sort(OccuppiedIntervals.begin(), OccuppiedIntervals.end(),
+              ByStartThenEnd);
+
+    std::vector<BitInterval> Coalesced;
+    Coalesced.reserve(OccuppiedIntervals.size());
+
+    for (const BitInterval &Cur : OccuppiedIntervals) {
+      // A new run starts only when there is a real gap before Cur.
+      const bool StartsNewRun =
+          Coalesced.empty() || Cur.First > Coalesced.back().Last;
+      if (StartsNewRun) {
+        Coalesced.push_back(Cur);
+        continue;
+      }
+      // Otherwise extend the current run if Cur reaches further.
+      BitInterval &Tail = Coalesced.back();
+      Tail.Last = std::max(Tail.Last, Cur.Last);
+    }
+
+    OccuppiedIntervals = Coalesced;
+  }
+
+  // Returns the gaps between the occupied intervals within [0, SizeInBits),
+  // i.e. the bit ranges that are padding.
+  //
+  // Expects OccuppiedIntervals to be sorted and non-overlapping, which is
+  // what MergeOccuppiedIntervals() produces.
+  std::vector<BitInterval> GetPaddingIntervals(uint64_t SizeInBits) const {
+    std::vector<BitInterval> Results;
+    // Fast path: a single interval covering the whole object means there is
+    // no padding. The interval must be read through front(); end() is the
+    // past-the-end iterator and dereferencing it is undefined behaviour.
+    if (OccuppiedIntervals.size() == 1 &&
+        OccuppiedIntervals.front().First == 0 &&
+        OccuppiedIntervals.front().Last == SizeInBits) {
+      return Results;
+    }
+    // At most one gap before each occupied interval plus one trailing gap.
+    Results.reserve(OccuppiedIntervals.size() + 1);
+    uint64_t CurrentPos = 0;
+    for (const BitInterval &OccupiedInterval : OccuppiedIntervals) {
+      if (OccupiedInterval.First > CurrentPos) {
+        Results.push_back(BitInterval{CurrentPos, OccupiedInterval.First});
+      }
+      CurrentPos = OccupiedInterval.Last;
+    }
+    // Tail padding after the last occupied bit.
+    if (SizeInBits > CurrentPos) {
+      Results.push_back(BitInterval{CurrentPos, SizeInBits});
+    }
+    return Results;
+  }
+
+  // Emits IR that zeroes the bits of PaddingInterval in the object at Ptr.
+  //
+  // Whole bytes inside the interval are overwritten with zero. A byte that
+  // is only partially covered (at either end of the interval, or when the
+  // interval lies within one byte) is loaded, ANDed with a mask that keeps
+  // the bits outside the interval, and stored back.
+  //
+  // NOTE(review): bit K of a byte is taken to be (1 << K), i.e. bit offsets
+  // count from the least significant bit. Confirm this matches the bit-field
+  // layout on big-endian targets.
+  void ClearPadding(Value *Ptr, const BitInterval &PaddingInterval) {
+    auto *I8Ptr = CGF.Builder.CreateBitCast(Ptr, CGF.Int8PtrTy);
+    auto *Zero = ConstantInt::get(CGF.Int8Ty, 0);
+
+    // Calculate byte indices and bit positions
+    auto StartByte = PaddingInterval.First / CharWidth;
+    const auto StartBit = PaddingInterval.First % CharWidth;
+    const auto EndByte = PaddingInterval.Last / CharWidth;
+    const auto EndBit = PaddingInterval.Last % CharWidth;
+
+    // Address of the byte at ByteIndex, relative to Ptr.
+    const auto ByteAddress = [&](uint64_t ByteIndex) {
+      auto *Index = ConstantInt::get(CGF.IntTy, ByteIndex);
+      auto *Element = CGF.Builder.CreateGEP(CGF.Int8Ty, I8Ptr, Index);
+      return Address(Element, CGF.Int8Ty, CharUnits::One());
+    };
+
+    // Mask with exactly the bits [Lo, Hi) of a byte set.
+    const auto BitsInRange = [](uint64_t Lo, uint64_t Hi) {
+      return static_cast<uint8_t>(((1u << Hi) - 1u) & ~((1u << Lo) - 1u));
+    };
+
+    // Emits `byte &= ~PaddingMask`: the padding bits are cleared and all
+    // other bits of the byte are preserved. The complement is narrowed back
+    // to 8 bits so that the constant fits the i8 type.
+    const auto ClearBitsInByte = [&](uint64_t ByteIndex,
+                                     uint8_t PaddingMask) {
+      Address ElementAddr = ByteAddress(ByteIndex);
+      auto *Loaded = CGF.Builder.CreateLoad(ElementAddr);
+      auto *KeepMask = ConstantInt::get(
+          CGF.Int8Ty, static_cast<uint8_t>(~PaddingMask));
+      auto *NewValue = CGF.Builder.CreateAnd(Loaded, KeepMask);
+      CGF.Builder.CreateStore(NewValue, ElementAddr);
+    };
+
+    if (StartByte == EndByte) {
+      // Interval is within a single byte
+      ClearBitsInByte(StartByte, BitsInRange(StartBit, EndBit));
+      return;
+    }
+
+    // Handle the start byte
+    if (StartBit != 0) {
+      ClearBitsInByte(StartByte, BitsInRange(StartBit, CharWidth));
+      ++StartByte;
+    }
+
+    // Handle full bytes in the middle
+    for (auto Offset = StartByte; Offset < EndByte; ++Offset) {
+      CGF.Builder.CreateStore(Zero, ByteAddress(Offset));
+    }
+
+    // Handle the end byte
+    if (EndBit != 0) {
+      ClearBitsInByte(EndByte, BitsInRange(0, EndBit));
+    }
+  }
+
+  // Function being emitted; supplies the AST context, module data layout
+  // and IR builder used by the methods above.
+  CodeGenFunction &CGF;
+  // Width of a char in bits, cached from the AST context at construction.
+  const uint64_t CharWidth;
+  // Work list of (sub)objects still to be visited; run() pops from the back.
+  std::deque<Data> Queue;
----------------
huixie90 wrote:

Very good suggestion! Indeed, in practice the `T` in `atomic<T>` is usually a
very flat struct or even a builtin type, so the vector is indeed usually small.

https://github.com/llvm/llvm-project/pull/75371
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to