[clang] Implement resource binding type prefix mismatch flag setting logic (PR #97103)
Joshua Batista via cfe-commits
cfe-commits at lists.llvm.org
Mon Jul 8 18:02:52 PDT 2024
================
@@ -437,7 +460,406 @@ void SemaHLSL::handleShaderAttr(Decl *D, const ParsedAttr &AL) {
D->addAttr(NewAttr);
}
/// Classification of a declaration that carries a `register(...)`
/// annotation, collected so that the register type named in the annotation
/// can be checked against what the declaration actually is.
struct register_binding_flags {
  // Overall kind of the declaration; exactly one of these four is expected
  // to be set once classification is complete.
  bool resource{false}; // resource-typed variable, or a cbuffer/tbuffer
  bool udt{false};      // aggregate (struct/class/array) variable
  bool other{false};    // any variable not covered by the other kinds
  bool basic{false};    // variable of a builtin type

  // Resource classes seen on the declaration itself or, for an aggregate,
  // on any of its (nested) members.
  bool srv{false};
  bool uav{false};
  bool cbv{false};
  bool sampler{false};

  // An aggregate has at least one integral/enumeration/floating member.
  bool contains_numeric{false};
  // A numeric variable declared outside every cbuffer/tbuffer.
  bool default_globals{false};
};
+
+bool isDeclaredWithinCOrTBuffer(const Decl *decl) {
+ if (!decl)
+ return false;
+
+ // Traverse up the parent contexts
+ const DeclContext *context = decl->getDeclContext();
+ while (context) {
+ if (isa<HLSLBufferDecl>(context)) {
+ return true;
+ }
+ context = context->getParent();
+ }
+
+ return false;
+}
+
+const CXXRecordDecl *getRecordDeclFromVarDecl(VarDecl *SamplerUAVOrSRV) {
+ const Type *Ty = SamplerUAVOrSRV->getType()->getPointeeOrArrayElementType();
+ if (!Ty)
+ llvm_unreachable("Resource class must have an element type.");
+
+ if (const BuiltinType *BTy = dyn_cast<BuiltinType>(Ty)) {
+ return nullptr;
+ }
+
+ const CXXRecordDecl *TheRecordDecl = Ty->getAsCXXRecordDecl();
+ if (!TheRecordDecl)
+ llvm_unreachable("Resource class should have a resource type declaration.");
+
+ if (auto TDecl = dyn_cast<ClassTemplateSpecializationDecl>(TheRecordDecl))
+ TheRecordDecl = TDecl->getSpecializedTemplate()->getTemplatedDecl();
+ TheRecordDecl = TheRecordDecl->getCanonicalDecl();
+ return TheRecordDecl;
+}
+
+const HLSLResourceAttr *
+getHLSLResourceAttrFromEitherDecl(VarDecl *SamplerUAVOrSRV,
+ HLSLBufferDecl *CBufferOrTBuffer) {
+
+ if (SamplerUAVOrSRV) {
+ const CXXRecordDecl *TheRecordDecl =
+ getRecordDeclFromVarDecl(SamplerUAVOrSRV);
+ if (!TheRecordDecl)
+ return nullptr;
+ const auto *Attr = TheRecordDecl->getAttr<HLSLResourceAttr>();
+ return Attr;
+ } else if (CBufferOrTBuffer) {
+ const auto *Attr = CBufferOrTBuffer->getAttr<HLSLResourceAttr>();
+ return Attr;
+ }
+ llvm_unreachable("one of the two conditions should be true.");
+ return nullptr;
+}
+
+void traverseType(QualType T, register_binding_flags &r) {
+ if (T->isIntegralOrEnumerationType() || T->isFloatingType()) {
+ r.contains_numeric = true;
+ return;
+ } else if (const RecordType *RT = T->getAs<RecordType>()) {
+ RecordDecl *SubRD = RT->getDecl();
+ if (auto TDecl = dyn_cast<ClassTemplateSpecializationDecl>(SubRD)) {
+ auto TheRecordDecl = TDecl->getSpecializedTemplate()->getTemplatedDecl();
+ TheRecordDecl = TheRecordDecl->getCanonicalDecl();
+ const auto *Attr = TheRecordDecl->getAttr<HLSLResourceAttr>();
+ llvm::hlsl::ResourceClass DeclResourceClass = Attr->getResourceClass();
+ switch (DeclResourceClass) {
+ case llvm::hlsl::ResourceClass::SRV: {
+ r.srv = true;
+ break;
+ }
+ case llvm::hlsl::ResourceClass::UAV: {
+ r.uav = true;
+ break;
+ }
+ case llvm::hlsl::ResourceClass::CBuffer: {
+ r.cbv = true;
+ break;
+ }
+ case llvm::hlsl::ResourceClass::Sampler: {
+ r.sampler = true;
+ break;
+ }
+ }
+ }
+
+ else if (SubRD->isCompleteDefinition()) {
+ for (auto Field : SubRD->fields()) {
+ QualType T = Field->getType();
+ traverseType(T, r);
+ }
+ }
+ }
+}
+
+void setResourceClassFlagsFromRecordDecl(register_binding_flags &r,
+ const RecordDecl *RD) {
+ if (!RD)
+ return;
+
+ if (RD->isCompleteDefinition()) {
+ for (auto Field : RD->fields()) {
+ QualType T = Field->getType();
+ traverseType(T, r);
+ }
+ }
+}
+
+register_binding_flags HLSLFillRegisterBindingFlags(Sema &S, Decl *D) {
+ register_binding_flags r;
+ if (!isDeclaredWithinCOrTBuffer(D)) {
+ // make sure the type is a basic / numeric type
+ if (VarDecl *v = dyn_cast<VarDecl>(D)) {
+ QualType t = v->getType();
+ // a numeric variable will inevitably end up in $Globals buffer
+ if (t->isIntegralType(S.getASTContext()) || t->isFloatingType())
+ r.default_globals = true;
+ }
+ }
+ // Cbuffers and Tbuffers are HLSLBufferDecl types
+ HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D);
+ // Samplers, UAVs, and SRVs are VarDecl types
+ VarDecl *SamplerUAVOrSRV = dyn_cast<VarDecl>(D);
+
+ if (CBufferOrTBuffer) {
+ r.resource = true;
+ if (CBufferOrTBuffer->isCBuffer())
+ r.cbv = true;
+ else
+ r.srv = true;
+ } else if (SamplerUAVOrSRV) {
+ const HLSLResourceAttr *res_attr =
+ getHLSLResourceAttrFromEitherDecl(SamplerUAVOrSRV, CBufferOrTBuffer);
+ if (res_attr) {
+ llvm::hlsl::ResourceClass DeclResourceClass =
+ res_attr->getResourceClass();
+ r.resource = true;
+ switch (DeclResourceClass) {
+ case llvm::hlsl::ResourceClass::SRV: {
+ r.srv = true;
+ break;
+ }
+ case llvm::hlsl::ResourceClass::UAV: {
+ r.uav = true;
+ break;
+ }
+ case llvm::hlsl::ResourceClass::CBuffer: {
+ r.cbv = true;
+ break;
+ }
+ case llvm::hlsl::ResourceClass::Sampler: {
+ r.sampler = true;
+ break;
+ }
+ }
+ } else {
+ if (SamplerUAVOrSRV->getType()->isBuiltinType())
+ r.basic = true;
+ else if (SamplerUAVOrSRV->getType()->isAggregateType()) {
+ r.udt = true;
+ QualType VarType = SamplerUAVOrSRV->getType();
+ if (const RecordType *RT = VarType->getAs<RecordType>()) {
+ const RecordDecl *RD = RT->getDecl();
+ // recurse through members, set appropriate resource class flags.
+ setResourceClassFlagsFromRecordDecl(r, RD);
+ }
+ } else
+ r.other = true;
+ }
+ } else {
+ llvm_unreachable("unknown decl type");
+ }
+ return r;
+}
+
+static void ValidateMultipleRegisterAnnotations(Sema &S, Decl *D,
+ StringRef &Slot) {
+ // make sure that there are no register annotations applied to the decl
+ // with the same register type but different numbers
+ std::unordered_map<char, std::set<char>>
+ s; // store unique register type + numbers
+ std::set<char> starting_set = {Slot[1]};
+ s.insert(std::make_pair(Slot[0], starting_set));
+ for (auto it = D->attr_begin(); it != D->attr_end(); ++it) {
+ if (HLSLResourceBindingAttr *attr =
+ dyn_cast<HLSLResourceBindingAttr>(*it)) {
+ std::string otherSlot(attr->getSlot().data());
+
+ // insert into hash map
+ if (s.find(otherSlot[0]) != s.end()) {
+ // if the register type is already in the map, insert the number
+ // into the set (if it's not already there
+ s[otherSlot[0]].insert(otherSlot[1]);
+ } else {
+ // if the register type is not in the map, insert it with the number
+ std::set<char> otherSet;
+ otherSet.insert(otherSlot[1]);
+ s.insert(std::make_pair(otherSlot[0], otherSet));
+ }
+ }
+ }
+
+ for (auto regType : s) {
+ if (regType.second.size() > 1) {
+ std::string regTypeStr(1, regType.first);
+ S.Diag(D->getLocation(), diag::err_hlsl_conflicting_register_annotations)
+ << regTypeStr;
+ }
+ }
+}
+
+static void DiagnoseHLSLResourceRegType(Sema &S, SourceLocation &ArgLoc,
+ Decl *D, StringRef &Slot) {
+
+ // Samplers, UAVs, and SRVs are VarDecl types
+ VarDecl *SamplerUAVOrSRV = dyn_cast<VarDecl>(D);
+ // Cbuffers and Tbuffers are HLSLBufferDecl types
+ HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D);
+
+ // exactly one of these two types should be set
+ if (!SamplerUAVOrSRV && !CBufferOrTBuffer)
+ return;
+ if (SamplerUAVOrSRV && CBufferOrTBuffer)
+ return;
+
+ register_binding_flags f = HLSLFillRegisterBindingFlags(S, D);
+ assert((int)f.other + (int)f.resource + (int)f.basic + (int)f.udt == 1 &&
+ "only one resource analysis result should be expected");
+
+ // get the variable type
+ std::string typestr;
+ if (SamplerUAVOrSRV) {
+ QualType QT = SamplerUAVOrSRV->getType();
+ PrintingPolicy PP = S.getPrintingPolicy();
+ typestr = QualType::getAsString(QT.split(), PP);
+ } else
+ typestr = CBufferOrTBuffer->isCBuffer() ? "cbuffer" : "tbuffer";
----------------
bob80905 wrote:
Made a function to fetch the relevant data only at diagnostic emission time.
https://github.com/llvm/llvm-project/pull/97103
More information about the cfe-commits mailing list