================
@@ -437,7 +460,406 @@ void SemaHLSL::handleShaderAttr(Decl *D, const ParsedAttr &AL) {
     D->addAttr(NewAttr);
 }
 
+/// Classification of a declaration that carries a register(...) annotation,
+/// as computed by HLSLFillRegisterBindingFlags.
+struct register_binding_flags {
+  // Kind of declaration being bound; HLSLFillRegisterBindingFlags sets exactly
+  // one of these four.
+  bool resource{false}; // cbuffer/tbuffer, or a variable of resource type
+  bool udt{false};      // variable of a user-defined aggregate type
+  bool other{false};    // variable of any other non-builtin type
+  bool basic{false};    // variable of a builtin type
+
+  // Resource classes present, either on the declaration itself or, for a
+  // udt, on any of its (transitively nested) fields.
+  bool srv{false};
+  bool uav{false};
+  bool cbv{false};
+  bool sampler{false};
+
+  // A udt field of numeric (integral, enumeration or floating) type was found.
+  bool contains_numeric{false};
+  // An integral or floating-point variable declared outside every
+  // cbuffer/tbuffer, i.e. one that lands in the implicit $Globals buffer.
+  bool default_globals{false};
+};
+
+/// Returns true if \p decl is nested, directly or through any chain of
+/// enclosing declaration contexts, inside an HLSLBufferDecl (cbuffer/tbuffer).
+/// A null \p decl is never inside a buffer.
+bool isDeclaredWithinCOrTBuffer(const Decl *decl) {
+  if (!decl)
+    return false;
+
+  // Walk outward through the enclosing contexts until no parent remains.
+  for (const DeclContext *DC = decl->getDeclContext(); DC;
+       DC = DC->getParent()) {
+    if (isa<HLSLBufferDecl>(DC))
+      return true;
+  }
+  return false;
+}
+
+/// Returns the canonical declaration of the record that defines the type of
+/// \p SamplerUAVOrSRV, looking through pointer/array types via
+/// getPointeeOrArrayElementType. A class template specialization is mapped to
+/// its primary template's pattern, which is where callers look up
+/// HLSLResourceAttr.
+///
+/// Returns nullptr when the element type is not a C++ class (builtin scalars,
+/// vector types, enums, typedefs of those, ...). Callers treat that as
+/// "not a resource".
+static const CXXRecordDecl *
+getRecordDeclFromVarDecl(VarDecl *SamplerUAVOrSRV) {
+  const Type *Ty = SamplerUAVOrSRV->getType()->getPointeeOrArrayElementType();
+  if (!Ty)
+    llvm_unreachable("Resource class must have an element type.");
+
+  // Only class types can be resources. Non-class element types are a normal
+  // case here (e.g. a global float4 with a register annotation), not an
+  // internal error.
+  const CXXRecordDecl *TheRecordDecl = Ty->getAsCXXRecordDecl();
+  if (!TheRecordDecl)
+    return nullptr;
+
+  if (const auto *TDecl =
+          dyn_cast<ClassTemplateSpecializationDecl>(TheRecordDecl))
+    TheRecordDecl = TDecl->getSpecializedTemplate()->getTemplatedDecl();
+  return TheRecordDecl->getCanonicalDecl();
+}
+
+/// Returns the HLSLResourceAttr describing what a register annotation binds.
+/// At least one argument must be non-null; the variable is checked first.
+///  - For a variable, the attribute is read from the record declaration
+///    returned by getRecordDeclFromVarDecl (nullptr if there is none).
+///  - For a cbuffer/tbuffer, the attribute is read from the buffer itself.
+const HLSLResourceAttr *
+getHLSLResourceAttrFromEitherDecl(VarDecl *SamplerUAVOrSRV,
+                                  HLSLBufferDecl *CBufferOrTBuffer) {
+  if (SamplerUAVOrSRV) {
+    const CXXRecordDecl *RD = getRecordDeclFromVarDecl(SamplerUAVOrSRV);
+    return RD ? RD->getAttr<HLSLResourceAttr>() : nullptr;
+  }
+  if (CBufferOrTBuffer)
+    return CBufferOrTBuffer->getAttr<HLSLResourceAttr>();
+  llvm_unreachable("one of the two conditions should be true.");
+}
+
+/// Recursively inspects \p T, the type of a field of a user-defined struct.
+/// It accumulates into \p r whether the type contains numeric data and which
+/// resource classes (SRV / UAV / CBuffer / Sampler) appear inside it.
+/// Array types are inspected through their base element type.
+static void traverseType(QualType T, register_binding_flags &r) {
+  // An array of numerics, resources or structs contributes exactly what its
+  // element type contributes.
+  const Type *Ty = T->getBaseElementTypeUnsafe();
+
+  // Scalars and vectors of scalars (e.g. float4) are numeric data.
+  if (Ty->isIntegralOrEnumerationType() || Ty->isFloatingType() ||
+      Ty->isVectorType()) {
+    r.contains_numeric = true;
+    return;
+  }
+
+  const RecordType *RT = Ty->getAs<RecordType>();
+  if (!RT)
+    return;
+  const RecordDecl *SubRD = RT->getDecl();
+
+  // Resource types are class template specializations whose primary template
+  // carries an HLSLResourceAttr. A specialization without that attribute is an
+  // ordinary struct and falls through to the field scan below.
+  if (const auto *TDecl = dyn_cast<ClassTemplateSpecializationDecl>(SubRD)) {
+    const CXXRecordDecl *TemplatedRD =
+        TDecl->getSpecializedTemplate()->getTemplatedDecl()->getCanonicalDecl();
+    if (const auto *Attr = TemplatedRD->getAttr<HLSLResourceAttr>()) {
+      switch (Attr->getResourceClass()) {
+      case llvm::hlsl::ResourceClass::SRV:
+        r.srv = true;
+        break;
+      case llvm::hlsl::ResourceClass::UAV:
+        r.uav = true;
+        break;
+      case llvm::hlsl::ResourceClass::CBuffer:
+        r.cbv = true;
+        break;
+      case llvm::hlsl::ResourceClass::Sampler:
+        r.sampler = true;
+        break;
+      }
+      return;
+    }
+  }
+
+  // Plain struct: recurse into every field.
+  if (SubRD->isCompleteDefinition())
+    for (const FieldDecl *Field : SubRD->fields())
+      traverseType(Field->getType(), r);
+}
+
+/// Scans every field of \p RD (recursively, via traverseType) and records in
+/// \p r the resource classes and numeric data found. A null record, or one
+/// that is not a complete definition, contributes nothing.
+void setResourceClassFlagsFromRecordDecl(register_binding_flags &r,
+                                         const RecordDecl *RD) {
+  if (!RD || !RD->isCompleteDefinition())
+    return;
+  for (const FieldDecl *Field : RD->fields())
+    traverseType(Field->getType(), r);
+}
+
+/// Computes the register_binding_flags for \p D, which must be either an
+/// HLSLBufferDecl (cbuffer/tbuffer) or a VarDecl.
+///
+/// A cbuffer/tbuffer, or a variable whose type resolves to a record carrying
+/// an HLSLResourceAttr, is flagged as a resource together with its resource
+/// class. Any other variable is classified as basic (builtin type), udt
+/// (aggregate; its record fields are scanned for resources) or other.
+/// Integral and floating-point variables declared outside every
+/// cbuffer/tbuffer also get default_globals.
+register_binding_flags HLSLFillRegisterBindingFlags(Sema &S, Decl *D) {
+  register_binding_flags Flags;
+
+  // Cbuffers and tbuffers are HLSLBufferDecls; samplers, UAVs, SRVs and plain
+  // variables are VarDecls.
+  HLSLBufferDecl *BufferDecl = dyn_cast<HLSLBufferDecl>(D);
+  VarDecl *Var = dyn_cast<VarDecl>(D);
+
+  // A numeric variable outside every cbuffer/tbuffer ends up in the implicit
+  // $Globals buffer.
+  if (Var && !isDeclaredWithinCOrTBuffer(D)) {
+    QualType VarTy = Var->getType();
+    if (VarTy->isIntegralType(S.getASTContext()) || VarTy->isFloatingType())
+      Flags.default_globals = true;
+  }
+
+  if (BufferDecl) {
+    Flags.resource = true;
+    // A cbuffer binds as a CBV, a tbuffer as an SRV.
+    if (BufferDecl->isCBuffer())
+      Flags.cbv = true;
+    else
+      Flags.srv = true;
+    return Flags;
+  }
+
+  if (!Var)
+    llvm_unreachable("unknown decl type");
+
+  if (const HLSLResourceAttr *ResAttr = getHLSLResourceAttrFromEitherDecl(
+          Var, /*CBufferOrTBuffer=*/nullptr)) {
+    Flags.resource = true;
+    switch (ResAttr->getResourceClass()) {
+    case llvm::hlsl::ResourceClass::SRV:
+      Flags.srv = true;
+      break;
+    case llvm::hlsl::ResourceClass::UAV:
+      Flags.uav = true;
+      break;
+    case llvm::hlsl::ResourceClass::CBuffer:
+      Flags.cbv = true;
+      break;
+    case llvm::hlsl::ResourceClass::Sampler:
+      Flags.sampler = true;
+      break;
+    }
+    return Flags;
+  }
+
+  // Not a resource: classify the variable's type.
+  QualType VarType = Var->getType();
+  if (VarType->isBuiltinType()) {
+    Flags.basic = true;
+  } else if (VarType->isAggregateType()) {
+    Flags.udt = true;
+    // Recurse through struct members to collect embedded resource classes.
+    if (const RecordType *RT = VarType->getAs<RecordType>())
+      setResourceClassFlagsFromRecordDecl(Flags, RT->getDecl());
+  } else {
+    Flags.other = true;
+  }
+  return Flags;
+}
+
+/// Diagnoses a declaration carrying register annotations with the same
+/// register type but different register numbers, e.g. register(b0) together
+/// with register(b1).
+///
+/// \p Slot is the slot string of the annotation being handled. Its first
+/// character is the register type and the remainder is the register number.
+/// It is compared against every HLSLResourceBindingAttr already attached to
+/// \p D. One err_hlsl_conflicting_register_annotations is emitted per
+/// conflicting register type, in ascending order of the type character.
+static void ValidateMultipleRegisterAnnotations(Sema &S, Decl *D,
+                                                StringRef &Slot) {
+  // An empty slot has no register type to compare against.
+  if (Slot.empty())
+    return;
+
+  // Distinct (register type, full register number) pairs seen on D. The whole
+  // number is kept so that e.g. b1 and b10 are recognized as different. An
+  // ordered set keeps the diagnostic order deterministic.
+  std::set<std::pair<char, std::string>> Bindings;
+  Bindings.insert({Slot[0], Slot.substr(1).str()});
+  for (const auto *Attr : D->specific_attrs<HLSLResourceBindingAttr>()) {
+    StringRef OtherSlot = Attr->getSlot();
+    if (OtherSlot.empty())
+      continue;
+    Bindings.insert({OtherSlot[0], OtherSlot.substr(1).str()});
+  }
+
+  // Entries are sorted by register type, so all numbers used with one type
+  // are adjacent. Diagnose once per type that has more than one number.
+  for (auto It = Bindings.begin(), End = Bindings.end(); It != End;) {
+    const char RegType = It->first;
+    unsigned NumDistinct = 0;
+    for (; It != End && It->first == RegType; ++It)
+      ++NumDistinct;
+    if (NumDistinct > 1)
+      S.Diag(D->getLocation(), diag::err_hlsl_conflicting_register_annotations)
+          << std::string(1, RegType);
+  }
+}
+
+static void DiagnoseHLSLResourceRegType(Sema &S, SourceLocation &ArgLoc,
+                                        Decl *D, StringRef &Slot) {
+
+  // Samplers, UAVs, and SRVs are VarDecl types
+  VarDecl *SamplerUAVOrSRV = dyn_cast<VarDecl>(D);
+  // Cbuffers and Tbuffers are HLSLBufferDecl types
+  HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D);
+
+  // exactly one of these two types should be set
+  if (!SamplerUAVOrSRV && !CBufferOrTBuffer)
+    return;
+  if (SamplerUAVOrSRV && CBufferOrTBuffer)
+    return;
----------------
bob80905 wrote:

As you mentioned earlier, SamplerUAVOrSRV isn't a perfect variable name, but
it conveys that the given decl is a VarDecl. In all the tests I've seen,
register annotations are only ever applied to VarDecls or HLSLBufferDecls.
If neither is set, the source passed to the compiler is most likely
erroneous, and an error will likely have been emitted before reaching this
point.
If both are set, that is also unexpected, because an HLSLBufferDecl is not a
VarDecl. It is a distinct class that derives from Decl without going through
VarDecl.
So these two checks only verify that the register annotation is applied to an
appropriate kind of declaration.
On second thought, I think I should put an llvm_unreachable under each case.

https://github.com/llvm/llvm-project/pull/97103
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to