[clang] Implement resource binding type prefix mismatch diagnostic infrastructure (PR #97103)
Damyan Pepper via cfe-commits
cfe-commits at lists.llvm.org
Tue Jul 9 18:11:21 PDT 2024
================
@@ -437,7 +444,419 @@ void SemaHLSL::handleShaderAttr(Decl *D, const ParsedAttr &AL) {
D->addAttr(NewAttr);
}
// Summary of what the register-binding analysis learned about one
// register-annotated declaration.
struct RegisterBindingFlags {
  // Classification of the declaration as a whole: a resource, a
  // user-defined aggregate, a builtin ("basic") type, or anything else.
  bool Resource{false};
  bool UDT{false};
  bool Other{false};
  bool Basic{false};

  // Resource classes found on the declaration itself or among its members.
  bool SRV{false};
  bool UAV{false};
  bool CBV{false};
  bool Sampler{false};

  // Set when a (possibly nested) member has an integral, enumeration or
  // floating-point type.
  bool ContainsNumeric{false};
  // Set for an integral/floating-point variable that is not declared
  // directly inside a cbuffer/tbuffer (it lands in the $Globals buffer).
  bool DefaultGlobals{false};
};
+
+bool isDeclaredWithinCOrTBuffer(const Decl *decl) {
+ if (!decl)
+ return false;
+
+ // Traverse up the parent contexts
+ const DeclContext *context = decl->getDeclContext();
+ if (isa<HLSLBufferDecl>(context)) {
+ return true;
+ }
+
+ return false;
+}
+
+const CXXRecordDecl *getRecordDeclFromVarDecl(VarDecl *VD) {
+ const Type *Ty = VD->getType()->getPointeeOrArrayElementType();
+ assert(Ty && "Resource class must have an element type.");
+
+ if (const auto *BTy = dyn_cast<BuiltinType>(Ty))
+ return nullptr;
+
+ const CXXRecordDecl *TheRecordDecl = Ty->getAsCXXRecordDecl();
+ assert(TheRecordDecl &&
+ "Resource class should have a resource type declaration.");
+
+ if (auto TDecl = dyn_cast<ClassTemplateSpecializationDecl>(TheRecordDecl))
+ TheRecordDecl = TDecl->getSpecializedTemplate()->getTemplatedDecl();
+ TheRecordDecl = TheRecordDecl->getCanonicalDecl();
+ return TheRecordDecl;
+}
+
+const HLSLResourceAttr *
+getHLSLResourceAttrFromEitherDecl(VarDecl *VD,
+ HLSLBufferDecl *CBufferOrTBuffer) {
+
+ if (VD) {
+ const CXXRecordDecl *TheRecordDecl = getRecordDeclFromVarDecl(VD);
+ if (!TheRecordDecl)
+ return nullptr;
+ const auto *Attr = TheRecordDecl->getAttr<HLSLResourceAttr>();
+ return Attr;
+ } else if (CBufferOrTBuffer) {
+ const auto *Attr = CBufferOrTBuffer->getAttr<HLSLResourceAttr>();
+ return Attr;
+ }
+ llvm_unreachable("one of the two conditions should be true.");
+ return nullptr;
+}
+
+void traverseType(QualType T, RegisterBindingFlags &r) {
+ if (T->isIntegralOrEnumerationType() || T->isFloatingType()) {
+ r.ContainsNumeric = true;
+ return;
+ }
+ const RecordType *RT = T->getAs<RecordType>();
+ if (!RT)
+ return;
+
+ RecordDecl *SubRD = RT->getDecl();
+ if (auto TDecl = dyn_cast<ClassTemplateSpecializationDecl>(SubRD)) {
+ auto TheRecordDecl = TDecl->getSpecializedTemplate()->getTemplatedDecl();
+ TheRecordDecl = TheRecordDecl->getCanonicalDecl();
+ const auto *Attr = TheRecordDecl->getAttr<HLSLResourceAttr>();
+ llvm::hlsl::ResourceClass DeclResourceClass = Attr->getResourceClass();
+ switch (DeclResourceClass) {
+ case llvm::hlsl::ResourceClass::SRV:
+ r.SRV = true;
+ break;
+ case llvm::hlsl::ResourceClass::UAV:
+ r.UAV = true;
+ break;
+ case llvm::hlsl::ResourceClass::CBuffer:
+ r.CBV = true;
+ break;
+ case llvm::hlsl::ResourceClass::Sampler:
+ r.Sampler = true;
+ break;
+ }
+ }
+
+ else if (SubRD->isCompleteDefinition()) {
+ for (auto Field : SubRD->fields()) {
+ QualType T = Field->getType();
+ traverseType(T, r);
+ }
+ }
+}
+
+void setResourceClassFlagsFromRecordDecl(RegisterBindingFlags &r,
+ const RecordDecl *RD) {
+ if (!RD)
+ return;
+
+ if (RD->isCompleteDefinition()) {
+ for (auto Field : RD->fields()) {
+ QualType T = Field->getType();
+ traverseType(T, r);
+ }
+ }
+}
+
+RegisterBindingFlags HLSLFillRegisterBindingFlags(Sema &S, Decl *D) {
+ RegisterBindingFlags r;
+ if (!isDeclaredWithinCOrTBuffer(D)) {
+ // make sure the type is a basic / numeric type
+ if (VarDecl *v = dyn_cast<VarDecl>(D)) {
+ QualType t = v->getType();
+ // a numeric variable will inevitably end up in $Globals buffer
+ if (t->isIntegralType(S.getASTContext()) || t->isFloatingType())
+ r.DefaultGlobals = true;
+ }
+ }
+ // Cbuffers and Tbuffers are HLSLBufferDecl types
+ HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D);
+ // Samplers, UAVs, and SRVs are VarDecl types
+ VarDecl *VD = dyn_cast<VarDecl>(D);
+
+ assert(((VD && !CBufferOrTBuffer) || (!VD && CBufferOrTBuffer)) &&
+ "either VD or CBufferOrTBuffer should be set");
+
+ if (CBufferOrTBuffer) {
+ r.Resource = true;
+ if (CBufferOrTBuffer->isCBuffer())
+ r.CBV = true;
+ else
+ r.SRV = true;
+ } else if (VD) {
+ const HLSLResourceAttr *res_attr =
+ getHLSLResourceAttrFromEitherDecl(VD, CBufferOrTBuffer);
+ if (res_attr) {
+ llvm::hlsl::ResourceClass DeclResourceClass =
+ res_attr->getResourceClass();
+ r.Resource = true;
+ switch (DeclResourceClass) {
+ case llvm::hlsl::ResourceClass::SRV: {
+ r.SRV = true;
+ break;
+ }
+ case llvm::hlsl::ResourceClass::UAV: {
+ r.UAV = true;
+ break;
+ }
+ case llvm::hlsl::ResourceClass::CBuffer: {
+ r.CBV = true;
+ break;
+ }
+ case llvm::hlsl::ResourceClass::Sampler: {
+ r.Sampler = true;
+ break;
+ }
+ }
+ } else {
+ if (VD->getType()->isBuiltinType())
+ r.Basic = true;
+ else if (VD->getType()->isAggregateType()) {
+ r.UDT = true;
+ QualType VarType = VD->getType();
+ if (const RecordType *RT = VarType->getAs<RecordType>()) {
+ const RecordDecl *RD = RT->getDecl();
+ // recurse through members, set appropriate resource class flags.
+ setResourceClassFlagsFromRecordDecl(r, RD);
+ }
+ } else
+ r.Other = true;
+ }
+ }
+ return r;
+}
+
+int getRegisterTypeIndex(StringRef Slot) {
+ switch (Slot[0]) {
+ case 't':
+ case 'T':
+ return 0;
+ case 'u':
+ case 'U':
+ return 1;
+ case 'b':
+ case 'B ':
+ return 2;
+ case 's':
+ case 'S':
+ return 3;
+ case 'c':
+ case 'C':
+ return 4;
+ case 'i':
+ case 'I':
+ return 5;
+ default:
+ llvm_unreachable("invalid register type");
+ }
+}
+
+static void ValidateMultipleRegisterAnnotations(Sema &S, Decl *D,
+ StringRef &Slot) {
+ // make sure that there are no tworegister annotations
+ // applied to the decl with the same register type
+ bool RegisterTypesDetected[6] = {false};
+ RegisterTypesDetected[getRegisterTypeIndex(Slot)] = true;
+
+ for (auto it = D->attr_begin(); it != D->attr_end(); ++it) {
+ if (HLSLResourceBindingAttr *attr =
+ dyn_cast<HLSLResourceBindingAttr>(*it)) {
+
+ int registerTypeIndex = getRegisterTypeIndex(attr->getSlot());
+ if (RegisterTypesDetected[registerTypeIndex]) {
+ S.Diag(D->getLocation(),
+ diag::err_hlsl_conflicting_register_annotations)
+ << attr->getSlot().substr(0, 1);
+ } else {
+ RegisterTypesDetected[registerTypeIndex] = true;
+ }
+ }
+ }
+}
+
+std::string getHLSLResourceTypeStr(Sema &S, Decl *D) {
+ VarDecl *VD = dyn_cast<VarDecl>(D);
+ HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D);
+
+ if (VD) {
+ QualType QT = VD->getType();
+ PrintingPolicy PP = S.getPrintingPolicy();
+ return QualType::getAsString(QT.split(), PP);
+ } else {
+ return CBufferOrTBuffer->isCBuffer() ? "cbuffer" : "tbuffer";
+ }
+}
+
+static void DiagnoseHLSLResourceRegType(Sema &S, SourceLocation &ArgLoc,
+ Decl *D, StringRef &Slot) {
+
+ // Samplers, UAVs, and SRVs are VarDecl types
+ VarDecl *VD = dyn_cast<VarDecl>(D);
+ // Cbuffers and Tbuffers are HLSLBufferDecl types
+ HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D);
+
+ // exactly one of these two types should be set
+ assert(((VD && !CBufferOrTBuffer) || (!VD && CBufferOrTBuffer)) &&
+ "either VD or CBufferOrTBuffer should be set");
+
+ RegisterBindingFlags f = HLSLFillRegisterBindingFlags(S, D);
+ assert((int)f.Other + (int)f.Resource + (int)f.Basic + (int)f.UDT == 1 &&
+ "only one resource analysis result should be expected");
+
+ std::string registerType(Slot.substr(0, 1));
+
+ // first, if "other" is set, emit an error
+ if (f.Other) {
+ S.Diag(ArgLoc, diag::err_hlsl_unsupported_register_type_and_variable_type)
+ << Slot << getHLSLResourceTypeStr(S, D);
+ return;
+ }
+
+ // next, if multiple register annotations exist, check that none conflict.
+ ValidateMultipleRegisterAnnotations(S, D, Slot);
+
+ // next, if resource is set, make sure the register type in the register
+ // annotation is compatible with the variable's resource type.
+ if (f.Resource) {
+ const HLSLResourceAttr *res_attr =
+ getHLSLResourceAttrFromEitherDecl(VD, CBufferOrTBuffer);
+ assert(res_attr && "any decl that set the resource flag on analysis should "
+ "have a resource attribute attached.");
+ const llvm::hlsl::ResourceClass DeclResourceClass =
+ res_attr->getResourceClass();
+
+ switch (DeclResourceClass) {
+ case llvm::hlsl::ResourceClass::SRV: {
+ if (Slot[0] != 't') {
----------------
damyanp wrote:
These comparisons are case-sensitive: for example, `Slot[0] != 't'` rejects an uppercase `T`.
I recommend you centralize the conversion of a Slot into a register type in one helper.
https://github.com/llvm/llvm-project/pull/97103
More information about the cfe-commits
mailing list