[clang] Implement resource binding type prefix mismatch diagnostic infrastructure (PR #97103)
Damyan Pepper via cfe-commits
cfe-commits at lists.llvm.org
Mon Aug 19 17:53:59 PDT 2024
================
@@ -459,7 +468,413 @@ void SemaHLSL::handleResourceClassAttr(Decl *D, const ParsedAttr &AL) {
D->addAttr(HLSLResourceClassAttr::Create(getASTContext(), RC, ArgLoc));
}
-void SemaHLSL::handleResourceBindingAttr(Decl *D, const ParsedAttr &AL) {
+/// Result of classifying a declaration that carries a `register(...)`
+/// annotation; populated by HLSLFillRegisterBindingFlags.
+struct RegisterBindingFlags {
+  // Broad category of the declaration.
+  bool Resource{false};
+  bool UDT{false};
+  bool Other{false};
+  bool Basic{false};
+
+  // Resource classes found on the declaration itself or on resources nested
+  // inside a user-defined type.
+  bool SRV{false};
+  bool UAV{false};
+  bool CBV{false};
+  bool Sampler{false};
+
+  // Set when a (possibly nested) member of a user-defined type has integral,
+  // enumeration or floating-point type.
+  bool ContainsNumeric{false};
+  // Set for non-groupshared variables outside any cbuffer/tbuffer whose
+  // (array element) type is integral or floating-point.
+  bool DefaultGlobals{false};
+};
+
+/// Returns true when \p TheDecl is declared directly inside a cbuffer or
+/// tbuffer, i.e. its immediate DeclContext is an HLSLBufferDecl. Only the
+/// immediate context is inspected; enclosing contexts are not walked.
+/// A null \p TheDecl yields false.
+static bool isDeclaredWithinCOrTBuffer(const Decl *TheDecl) {
+  return TheDecl != nullptr && isa<HLSLBufferDecl>(TheDecl->getDeclContext());
+}
+
+/// Returns the C++ record declaration of \p VD's type, looking through a
+/// pointer or array (Type::getPointeeOrArrayElementType) to the element type.
+/// Returns nullptr when that element type is not a record type; callers treat
+/// nullptr as "no resource attribute to look for".
+static CXXRecordDecl *getRecordDeclFromVarDecl(VarDecl *VD) {
+  const Type *Ty = VD->getType()->getPointeeOrArrayElementType();
+  assert(Ty && "Resource class must have an element type.");
+
+  // Builtin scalars are not the only non-record element types that reach
+  // here: enum and vector-typed variables do too. getAsCXXRecordDecl()
+  // returns nullptr for every non-record type, so no assertion is needed.
+  return Ty->getAsCXXRecordDecl();
+}
+
+/// Sets the single flag in \p Flags that corresponds to \p DeclResourceClass
+/// (SRV, UAV, CBuffer -> CBV, Sampler). All other flags are left unchanged.
+static void setResourceClassFlagsFromDeclResourceClass(
+    RegisterBindingFlags &Flags, llvm::hlsl::ResourceClass DeclResourceClass) {
+  using llvm::hlsl::ResourceClass;
+  switch (DeclResourceClass) {
+  case ResourceClass::SRV:
+    Flags.SRV = true;
+    return;
+  case ResourceClass::UAV:
+    Flags.UAV = true;
+    return;
+  case ResourceClass::CBuffer:
+    Flags.CBV = true;
+    return;
+  case ResourceClass::Sampler:
+    Flags.Sampler = true;
+    return;
+  }
+}
+
+/// Returns the first attribute of type \p T found on a field of a record, or
+/// nullptr if no field carries one.
+///
+/// When \p VD is non-null, the record comes from VD's (element) type via
+/// getRecordDeclFromVarDecl and \p TheRecordDecl is ignored; if VD's type has
+/// no record declaration, nullptr is returned. Otherwise \p TheRecordDecl is
+/// used and must be non-null.
+///
+/// For a ClassTemplateSpecializationDecl, the fields of the canonical
+/// templated declaration of its class template are searched instead. This is
+/// needed while resources such as RWBuffer are defined externally.
+template <typename T>
+static const T *
+getSpecifiedHLSLAttrFromVarDeclOrRecordDecl(VarDecl *VD,
+                                            RecordDecl *TheRecordDecl) {
+  if (VD) {
+    TheRecordDecl = getRecordDeclFromVarDecl(VD);
+    if (!TheRecordDecl)
+      return nullptr;
+  }
+  if (!TheRecordDecl)
+    llvm_unreachable("TheRecordDecl should not be null");
+
+  // Pick the declaration whose fields hold the attribute.
+  RecordDecl *SearchDecl = TheRecordDecl;
+  if (auto *Spec = dyn_cast<ClassTemplateSpecializationDecl>(TheRecordDecl))
+    SearchDecl =
+        Spec->getSpecializedTemplate()->getTemplatedDecl()->getCanonicalDecl();
+
+  for (FieldDecl *Field : SearchDecl->fields())
+    if (const T *Found = Field->getAttr<T>())
+      return Found;
+  return nullptr;
+}
+
+/// Accumulates into \p Flags what a member of type \p TheQualTy contributes:
+///  - integral, enumeration or floating-point types set ContainsNumeric;
+///  - a record (possibly wrapped in arrays) that is a resource sets the flag
+///    for its resource class, looked up first on the record's fields (or
+///    those of its class template) and then on the record itself;
+///  - any other record is searched recursively, field by field.
+/// Other types, including arrays of numerics, contribute nothing.
+static void setFlagsFromType(QualType TheQualTy, RegisterBindingFlags &Flags) {
+  if (TheQualTy->isIntegralOrEnumerationType() ||
+      TheQualTy->isFloatingType()) {
+    Flags.ContainsNumeric = true;
+    return;
+  }
+
+  // Strip every array dimension to reach the element type.
+  const clang::Type *ElemTy = TheQualTy.getTypePtr();
+  while (ElemTy->isArrayType())
+    ElemTy = ElemTy->getArrayElementTypeNoTypeQual();
+
+  const RecordType *RecTy = ElemTy->getAs<RecordType>();
+  if (!RecTy)
+    return;
+  RecordDecl *Rec = RecTy->getDecl();
+
+  // The resource class attribute may sit on a member (the handle) or on the
+  // record declaration itself.
+  const HLSLResourceClassAttr *ClassAttr =
+      getSpecifiedHLSLAttrFromVarDeclOrRecordDecl<HLSLResourceClassAttr>(
+          nullptr, Rec);
+  if (!ClassAttr)
+    ClassAttr = Rec->getAttr<HLSLResourceClassAttr>();
+
+  if (ClassAttr) {
+    setResourceClassFlagsFromDeclResourceClass(Flags,
+                                               ClassAttr->getResourceClass());
+    return;
+  }
+
+  // Not a resource itself: inspect its members.
+  for (const FieldDecl *Member : Rec->fields())
+    setFlagsFromType(Member->getType(), Flags);
+}
+
+/// For a complete record definition, accumulates into \p Flags the
+/// contribution of every field's type (see setFlagsFromType). A null or
+/// incomplete \p RD leaves \p Flags untouched.
+static void setResourceClassFlagsFromRecordDecl(RegisterBindingFlags &Flags,
+                                                const RecordDecl *RD) {
+  if (RD == nullptr || !RD->isCompleteDefinition())
+    return;
+  for (const FieldDecl *FD : RD->fields())
+    setFlagsFromType(FD->getType(), Flags);
+}
+
+/// Classifies \p TheDecl (a VarDecl or an HLSLBufferDecl) for register-binding
+/// diagnostics and returns the resulting flags.
+///
+/// Exactly one of Other, Resource, Basic or UDT is set:
+///  - groupshared declarations: Other, and nothing else is computed;
+///  - cbuffer / tbuffer: Resource plus CBV or SRV respectively;
+///  - variables whose record type has a field carrying an
+///    HLSLResourceClassAttr: Resource plus the matching class flag;
+///  - otherwise, by the (array element) type: arithmetic -> Basic,
+///    record -> UDT (with class / ContainsNumeric flags gathered from its
+///    fields), anything else -> Other.
+/// Independently, DefaultGlobals is set for variables not declared directly
+/// inside a cbuffer/tbuffer whose (array element) type is integral or
+/// floating-point.
+static RegisterBindingFlags HLSLFillRegisterBindingFlags(Sema &S,
+ Decl *TheDecl) {
+
+ // Cbuffers and Tbuffers are HLSLBufferDecl types
+ HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(TheDecl);
+ // Samplers, UAVs, and SRVs are VarDecl types
+ VarDecl *TheVarDecl = dyn_cast<VarDecl>(TheDecl);
+
+ assert(((TheVarDecl && !CBufferOrTBuffer) ||
+ (!TheVarDecl && CBufferOrTBuffer)) &&
+ "either TheVarDecl or CBufferOrTBuffer should be set");
+
+ RegisterBindingFlags Flags;
+
+ // groupshared declarations cannot be bound to a register at all; classify
+ // them as Other and skip the remaining analysis.
+ if (TheDecl->hasAttr<HLSLGroupSharedAddressSpaceAttr>()) {
+ Flags.Other = true;
+ return Flags;
+ }
+
+ if (!isDeclaredWithinCOrTBuffer(TheDecl)) {
+ // record whether this is a numeric (or array-of-numeric) variable
+ if (TheVarDecl) {
+ QualType TheQualTy = TheVarDecl->getType();
+ // a numeric variable or an array of numeric variables
+ // will inevitably end up in $Globals buffer
+ const clang::Type *TheBaseType = TheQualTy.getTypePtr();
+ while (TheBaseType->isArrayType())
+ TheBaseType = TheBaseType->getArrayElementTypeNoTypeQual();
+ if (TheBaseType->isIntegralType(S.getASTContext()) ||
+ TheBaseType->isFloatingType())
+ Flags.DefaultGlobals = true;
+ }
+ }
+
+ if (CBufferOrTBuffer) {
+ Flags.Resource = true;
+ if (CBufferOrTBuffer->isCBuffer())
+ Flags.CBV = true;
+ else
+ Flags.SRV = true;
+ } else if (TheVarDecl) {
+ // a resource type is recognised by a field (its handle) carrying the
+ // resource class attribute
+ const HLSLResourceClassAttr *resClassAttr =
+ getSpecifiedHLSLAttrFromVarDeclOrRecordDecl<HLSLResourceClassAttr>(
+ TheVarDecl, nullptr);
+ // element type after stripping all array dimensions
+ const clang::Type *TheBaseType = TheVarDecl->getType().getTypePtr();
+ while (TheBaseType->isArrayType())
+ TheBaseType = TheBaseType->getArrayElementTypeNoTypeQual();
+
+ if (resClassAttr) {
+ llvm::hlsl::ResourceClass DeclResourceClass =
+ resClassAttr->getResourceClass();
+ Flags.Resource = true;
+ setResourceClassFlagsFromDeclResourceClass(Flags, DeclResourceClass);
+ } else {
+ // NOTE(review): vector types (e.g. float4) are neither arithmetic nor
+ // record types, so they fall through to Other below -- confirm this is
+ // the intended classification.
+ if (TheBaseType->isArithmeticType())
+ Flags.Basic = true;
+ else if (TheBaseType->isRecordType()) {
+ Flags.UDT = true;
+ const RecordType *TheRecordTy = TheBaseType->getAs<RecordType>();
+ assert(TheRecordTy && "The Qual Type should be Record Type");
+ const RecordDecl *TheRecordDecl = TheRecordTy->getDecl();
+ // recurse through members, set appropriate resource class flags.
+ setResourceClassFlagsFromRecordDecl(Flags, TheRecordDecl);
+ } else
+ Flags.Other = true;
+ }
+ }
+ return Flags;
+}
+
+/// Register type named by the first letter of a `register(...)` slot:
+/// t -> SRV, u -> UAV, b -> CBuffer, s -> Sampler, c -> C, i -> I.
+/// The enumerator values are used directly as array indices and as
+/// diagnostic arguments, so their order must not change.
+enum RegisterType {
+  SRV,
+  UAV,
+  CBuffer,
+  Sampler,
+  C,
+  I,
+};
+
+/// Maps the first character of a register slot string (e.g. "t0", "B3") to
+/// its RegisterType, case-insensitively. \p Slot is indexed at 0
+/// unconditionally, so it must be non-empty.
+static RegisterType getRegisterTypeIndex(StringRef Slot) {
+  switch (Slot[0]) {
+  case 't':
+  case 'T':
+    return RegisterType::SRV;
+  case 'u':
+  case 'U':
+    return RegisterType::UAV;
+  case 'b':
+  case 'B':
+    return RegisterType::CBuffer;
+  case 's':
+  case 'S':
+    return RegisterType::Sampler;
+  case 'c':
+  case 'C':
+    return RegisterType::C;
+  // we don't need to check for 'i' here, because
+  // any attribute that has the 'i' register type
+  // will be immediately caught by handleResourceBindingAttr
+  // so it's impossible for the decl to already have an 'i' register type
+  default:
+    llvm_unreachable("invalid register type");
+  }
+}
+
+/// Diagnoses err_hlsl_duplicate_register_annotation on \p TheDecl when a
+/// register type is repeated among \p regType and the HLSLResourceBindingAttr
+/// annotations already attached to the declaration. The diagnostic argument
+/// is the repeated register type. Each (declaration, register type) pair is
+/// reported at most once.
+static void ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl,
+                                                RegisterType regType) {
+  // One slot per RegisterType enumerator. Sizing from the last enumerator
+  // keeps every RegisterType value, including I, a valid index.
+  bool RegisterTypesDetected[RegisterType::I + 1] = {false};
+  RegisterTypesDetected[regType] = true;
+
+  // Conflicts already diagnosed, per declaration, so the same duplicate is
+  // not reported again when this runs for later annotations on the same
+  // declaration.
+  // NOTE(review): this function-local static lives for the whole process.
+  // It is shared by every compilation, is not thread-safe, and never shrinks.
+  // A freed Decl's address may be reused by an unrelated Decl and suppress a
+  // legitimate diagnostic. Consider per-Sema state, or skipping the duplicate
+  // attribute instead.
+  static std::map<Decl *, std::set<RegisterType>> PreviousConflicts;
+
+  for (HLSLResourceBindingAttr *Binding :
+       TheDecl->specific_attrs<HLSLResourceBindingAttr>()) {
+    // Register type of an annotation already attached to the declaration
+    // (named distinctly from the parameter it is compared against).
+    RegisterType ExistingRegType = getRegisterTypeIndex(Binding->getSlot());
+    if (!RegisterTypesDetected[ExistingRegType]) {
+      RegisterTypesDetected[ExistingRegType] = true;
+      continue;
+    }
+    // insert() reports whether the pair is new; only diagnose it once.
+    if (!PreviousConflicts[TheDecl].insert(ExistingRegType).second)
+      continue;
+    S.Diag(TheDecl->getLocation(),
+           diag::err_hlsl_duplicate_register_annotation)
+        << ExistingRegType;
+  }
+}
+
+/// Returns a string naming \p TheDecl's type: the printed type of a VarDecl,
+/// or "cbuffer" / "tbuffer" for an HLSLBufferDecl. \p TheDecl must be one of
+/// those two kinds.
+static std::string getHLSLResourceTypeStr(Sema &S, Decl *TheDecl) {
+  if (const auto *VD = dyn_cast<VarDecl>(TheDecl))
+    return QualType::getAsString(VD->getType().split(),
+                                 S.getPrintingPolicy());
+
+  auto *Buffer = dyn_cast<HLSLBufferDecl>(TheDecl);
+  if (Buffer->isCBuffer())
+    return "cbuffer";
+  return "tbuffer";
+}
+
+static void DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc,
+ Decl *TheDecl, RegisterType regType) {
+
+ // Samplers, UAVs, and SRVs are VarDecl types
+ VarDecl *TheVarDecl = dyn_cast<VarDecl>(TheDecl);
+ // Cbuffers and Tbuffers are HLSLBufferDecl types
+ HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(TheDecl);
+
+ // exactly one of these two types should be set
+ assert(((TheVarDecl && !CBufferOrTBuffer) ||
+ (!TheVarDecl && CBufferOrTBuffer)) &&
+ "either TheVarDecl or CBufferOrTBuffer should be set");
+
+ RegisterBindingFlags Flags = HLSLFillRegisterBindingFlags(S, TheDecl);
+ assert((int)Flags.Other + (int)Flags.Resource + (int)Flags.Basic +
+ (int)Flags.UDT ==
+ 1 &&
+ "only one resource analysis result should be expected");
+
+ // first, if "other" is set, emit an error
+ if (Flags.Other) {
+ S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << regType;
+ return;
+ }
+
+ // next, if multiple register annotations exist, check that none conflict.
+ ValidateMultipleRegisterAnnotations(S, TheDecl, regType);
+
+ // next, if resource is set, make sure the register type in the register
+ // annotation is compatible with the variable's resource type.
+ if (Flags.Resource) {
+ const HLSLResourceAttr *resAttr = nullptr;
+ const HLSLResourceClassAttr *resClassAttr = nullptr;
+ if (CBufferOrTBuffer) {
+ resAttr = CBufferOrTBuffer->getAttr<HLSLResourceAttr>();
+ resClassAttr = CBufferOrTBuffer->getAttr<HLSLResourceClassAttr>();
+ } else if (TheVarDecl) {
+ resAttr = getSpecifiedHLSLAttrFromVarDeclOrRecordDecl<HLSLResourceAttr>(
+ TheVarDecl, nullptr);
+ resClassAttr =
+ getSpecifiedHLSLAttrFromVarDeclOrRecordDecl<HLSLResourceClassAttr>(
+ TheVarDecl, nullptr);
+ }
+
+ assert(resAttr && resClassAttr &&
+ "any decl that set the resource flag on analysis should "
+ "have a resource attribute and resource class attribute attached.");
+ const llvm::hlsl::ResourceClass DeclResourceClass =
+ resClassAttr->getResourceClass();
+
+ // confirm that the register type is bound to its expected resource class
+ static llvm::SmallVector<RegisterType>
+ ExpectedRegisterTypesForResourceClass = {
+ RegisterType::SRV,
+ RegisterType::UAV,
+ RegisterType::CBuffer,
+ RegisterType::Sampler,
+ };
----------------
damyanp wrote:
Any reason this can't be a c-style array? It is just for `.size()`. Could use `std::size(ExpectedRegisterTypesForResourceClass)` for that?
https://github.com/llvm/llvm-project/pull/97103
More information about the cfe-commits mailing list