================
@@ -437,7 +453,409 @@ void SemaHLSL::handleShaderAttr(Decl *D, const ParsedAttr &AL) {
     D->addAttr(NewAttr);
 }
 
+struct RegisterBindingFlags { // Bag of boolean flags; every flag defaults to false.
+  bool Resource = false; // NOTE(review): Resource/Udt/Other/Basic presumably classify the bound decl's type - confirm.
+  bool Udt = false;
+  bool Other = false;
+  bool Basic = false;
+
+  bool Srv = false; // NOTE(review): Srv/Uav/Cbv/Sampler presumably mirror the HLSL resource classes - confirm.
+  bool Uav = false;
+  bool Cbv = false;
+  bool Sampler = false;
+
+  bool ContainsNumeric = false; // traverseType() sets this on reaching an integral, enumeration, or floating type.
+  bool DefaultGlobals = false; // NOTE(review): not written anywhere in the visible hunk - confirm meaning at its users.
+};
+
+/// Reports whether \p decl is nested, at any depth, inside an HLSLBufferDecl.
+///
+/// Starts at the declaration's own DeclContext and follows getParent() until
+/// either an HLSLBufferDecl is found or the chain of contexts ends.
+///
+/// \param decl Declaration to inspect; a null pointer yields false.
+/// \returns true if any enclosing context is an HLSLBufferDecl.
+bool isDeclaredWithinCOrTBuffer(const Decl *decl) {
+  if (decl == nullptr)
+    return false;
+
+  for (const DeclContext *Enclosing = decl->getDeclContext();
+       Enclosing != nullptr; Enclosing = Enclosing->getParent())
+    if (isa<HLSLBufferDecl>(Enclosing))
+      return true;
+
+  return false;
+}
+
+/// Returns the canonical record declaration behind the type of \p VD, or
+/// nullptr when that type is a builtin type.
+///
+/// The variable's type is first reduced with getPointeeOrArrayElementType().
+/// When the resulting record is a class template specialization, the
+/// declaration templated by the specialized template is used instead, so the
+/// result does not depend on the particular template arguments.
+///
+/// \param VD Variable whose type is inspected; must not be null.
+const CXXRecordDecl *getRecordDeclFromVarDecl(VarDecl *VD) {
+  const Type *Ty = VD->getType()->getPointeeOrArrayElementType();
+  if (!Ty)
+    llvm_unreachable("Resource class must have an element type.");
+
+  // Builtin types carry no record declaration; only the kind of type matters
+  // here, so isa<> is used instead of binding an unused dyn_cast<> result.
+  if (isa<BuiltinType>(Ty))
+    return nullptr;
+
+  const CXXRecordDecl *TheRecordDecl = Ty->getAsCXXRecordDecl();
+  if (!TheRecordDecl)
+    llvm_unreachable(
+        "Resource class should have a resource type declaration.");
+
+  // Map a specialization back to the declaration of its primary template.
+  if (const auto *TDecl =
+          dyn_cast<ClassTemplateSpecializationDecl>(TheRecordDecl))
+    TheRecordDecl = TDecl->getSpecializedTemplate()->getTemplatedDecl();
+  return TheRecordDecl->getCanonicalDecl();
+}
+
+/// Looks up the HLSLResourceAttr for whichever of the two declarations is
+/// provided.
+///
+/// \p VD takes precedence: when it is non-null the attribute is read from the
+/// record declaration returned by getRecordDeclFromVarDecl(), and nullptr is
+/// returned if that helper yields no record. Otherwise the attribute is read
+/// directly from \p CBufferOrTBuffer.
+///
+/// \param VD Variable declaration, or null.
+/// \param CBufferOrTBuffer Buffer declaration, consulted only when \p VD is
+///        null.
+/// \returns The attribute found, or nullptr if the inspected declaration
+///          does not carry one. Passing two null pointers is a caller bug.
+const HLSLResourceAttr *
+getHLSLResourceAttrFromEitherDecl(VarDecl *VD,
+                                  HLSLBufferDecl *CBufferOrTBuffer) {
+  if (VD) {
+    const CXXRecordDecl *RD = getRecordDeclFromVarDecl(VD);
+    return RD ? RD->getAttr<HLSLResourceAttr>() : nullptr;
+  }
+
+  if (CBufferOrTBuffer)
+    return CBufferOrTBuffer->getAttr<HLSLResourceAttr>();
+
+  // llvm_unreachable does not return, so no trailing return is needed.
+  llvm_unreachable("one of the two conditions should be true.");
+}
+
+void traverseType(QualType T, RegisterBindingFlags &r) {
+  if (T->isIntegralOrEnumerationType() || T->isFloatingType()) {
+    r.ContainsNumeric = true;
+    return;
+  } else if (const RecordType *RT = T->getAs<RecordType>()) {
----------------
bob80905 wrote:

👍

https://github.com/llvm/llvm-project/pull/97103
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to