================
@@ -2361,6 +2361,121 @@ static bool DiagnoseHLSLRegisterAttribute(Sema &S,
SourceLocation &ArgLoc,
return ValidateMultipleRegisterAnnotations(S, D, RegType);
}
+/// Walks \p Ty and accounts for every register slot consumed by resources of
+/// class \p ResClass, starting at \p StartSlot.
+///
+/// Arrays multiply the number of slots taken by each resource they contain
+/// (an array that is not a ConstantArrayType counts as one element), records
+/// are traversed through their base classes and then their fields, and any
+/// other type consumes nothing.
+///
+/// \param Ty         Type to inspect; it is canonicalized first.
+/// \param StartSlot  In/out: first slot available to \p Ty. It is advanced
+///                   past every range that is accepted.
+/// \param Limit      Highest slot index that may be used (inclusive).
+/// \param ResClass   Only resources of this class consume slots.
+/// \param Ctx        Forwarded unchanged to the recursive calls.
+/// \param ArrayCount Product of the extents of the enclosing arrays.
+/// \returns false if the slot count exceeds the limit, true otherwise.
+static bool AccumulateHLSLResourceSlots(QualType Ty, uint64_t &StartSlot,
+                                        const uint64_t &Limit,
+                                        const ResourceClass ResClass,
+                                        ASTContext &Ctx,
+                                        uint64_t ArrayCount = 1) {
+  Ty = Ty.getCanonicalType();
+  const Type *T = Ty.getTypePtr();
+
+  // Early exit if already overflowed
+  if (StartSlot > Limit)
+    return false;
+
+  // Case 1: array type
+  if (const auto *AT = dyn_cast<ArrayType>(T)) {
+    uint64_t Count = 1;
+
+    if (const auto *CAT = dyn_cast<ConstantArrayType>(AT))
+      Count = CAT->getSize().getZExtValue();
+
+    // Saturate rather than wrap so that an oversized nest of arrays can never
+    // be mistaken for a small one.
+    const bool Wraps = Count != 0 && ArrayCount > UINT64_MAX / Count;
+    const uint64_t TotalCount = Wraps ? UINT64_MAX : ArrayCount * Count;
+
+    return AccumulateHLSLResourceSlots(AT->getElementType(), StartSlot, Limit,
+                                       ResClass, Ctx, TotalCount);
+  }
+
+  // Case 2: resource leaf
+  if (const auto *ResTy = dyn_cast<HLSLAttributedResourceType>(T)) {
+    // First ensure this resource counts towards the corresponding
+    // register type limit.
+    if (ResTy->getAttrs().ResourceClass != ResClass)
+      return true;
+
+    // A zero-extent array binds nothing. Returning here keeps the unsigned
+    // arithmetic below from wrapping when StartSlot is 0.
+    if (ArrayCount == 0)
+      return true;
+
+    // Validate the highest slot used. StartSlot <= Limit was established
+    // above, so neither subtraction can wrap.
+    if (ArrayCount - 1 > Limit - StartSlot)
+      return false;
+
+    // Advance StartSlot past the consumed range
+    StartSlot += ArrayCount;
+    return true;
+  }
+
+  // Case 3: struct / record
+  if (const auto *RT = dyn_cast<RecordType>(T)) {
+    const RecordDecl *RD = RT->getDecl();
+
+    // Base classes are visited before the record's own fields.
+    if (const auto *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
+      for (const CXXBaseSpecifier &Base : CXXRD->bases()) {
+        if (!AccumulateHLSLResourceSlots(Base.getType(), StartSlot, Limit,
+                                         ResClass, Ctx, ArrayCount))
+          return false;
+      }
+    }
+
+    for (const FieldDecl *Field : RD->fields()) {
+      if (!AccumulateHLSLResourceSlots(Field->getType(), StartSlot, Limit,
+                                       ResClass, Ctx, ArrayCount))
+        return false;
+    }
+
+    return true;
+  }
+
+  // Case 4: everything else
+  return true;
+}
+
+// return true if there is something invalid, false otherwise
+static bool ValidateRegisterNumber(const StringRef SlotNumStr, Decl *TheDecl,
+ ASTContext &Ctx, RegisterType RegTy,
+ unsigned &Result) {
+ uint64_t SlotNum;
+ if (SlotNumStr.getAsInteger(10, SlotNum))
+ return true;
+
+ const uint64_t Limit = UINT32_MAX;
+ if (SlotNum > Limit)
+ return true;
+
+ // after verifying the number doesn't exceed uint32max, we don't need
+ // to look further into c or i register types
+ if (RegTy == RegisterType::C || RegTy == RegisterType::I)
+ return false;
+
+ if (VarDecl *VD = dyn_cast<VarDecl>(TheDecl)) {
+ uint64_t BaseSlot = SlotNum;
+
+ if (!AccumulateHLSLResourceSlots(VD->getType(), SlotNum, Limit,
+ getResourceClass(RegTy), Ctx))
+ return true;
+
+ // After AccumulateHLSLResourceSlots runs, SlotNum is now
+ // the first free slot; last used was SlotNum - 1
+ if (BaseSlot > Limit)
+ return true;
+
+ SlotNumStr.getAsInteger(10, Result);
+ return false;
+ }
+ // handle the cbuffer case
+ if (dyn_cast<HLSLBufferDecl>(TheDecl)) {
+ // resources cannot be put within a cbuffer, so no need
+ // to analyze the structure since the register number
+ // won't be pushed any higher.
+ if (SlotNum > Limit)
----------------
tex3d wrote:
Ok.
Regarding the test, the additions look good, though I'm wondering if there's
another place that catches these cases too:
```
cbuffer CB : register(t0) { ... }
tbuffer TB : register(b0) { ... }
```
They are both `HLSLBufferDecl`, but the binding type is mismatched. I don't see
this in the test, but perhaps there's another test for this already.
https://github.com/llvm/llvm-project/pull/174028
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits