[clang] [HLSL] Collect explicit resource binding information (PR #111203)
Justin Bogner via cfe-commits
cfe-commits at lists.llvm.org
Tue Oct 15 08:33:55 PDT 2024
================
@@ -985,88 +1026,92 @@ SemaHLSL::TakeLocForHLSLAttribute(const HLSLAttributedResourceType *RT) {
return LocInfo;
}
-// get the record decl from a var decl that we expect
-// represents a resource
-static CXXRecordDecl *getRecordDeclFromVarDecl(VarDecl *VD) {
- const Type *Ty = VD->getType()->getPointeeOrArrayElementType();
- assert(Ty && "Resource must have an element type.");
-
- if (Ty->isBuiltinType())
- return nullptr;
-
- CXXRecordDecl *TheRecordDecl = Ty->getAsCXXRecordDecl();
- assert(TheRecordDecl && "Resource should have a resource type declaration.");
- return TheRecordDecl;
-}
-
+// Returns the handle type of a resource, or nullptr if Ty is not a resource.
+// A resource is a class/struct whose first field is the resource handle, of
+// type HLSLAttributedResourceType. Ty is not unwrapped: an array type is not a
+// record type, so an array of resources yields nullptr.
static const HLSLAttributedResourceType *
-findAttributedResourceTypeOnField(VarDecl *VD) {
- assert(VD != nullptr && "expected VarDecl");
- if (RecordDecl *RD = getRecordDeclFromVarDecl(VD)) {
- for (auto *FD : RD->fields()) {
- if (const HLSLAttributedResourceType *AttrResType =
- dyn_cast<HLSLAttributedResourceType>(FD->getType().getTypePtr()))
- return AttrResType;
+FindHandleTypeOnResource(const Type *Ty) {
+ // Only the first field is inspected; a resource class must declare its
+ // handle first.
+ if (const RecordDecl *RD = Ty->getAsCXXRecordDecl()) {
+ if (!RD->field_empty()) {
+ const FieldDecl *FirstFD = *RD->field_begin();
+ return dyn_cast<HLSLAttributedResourceType>(
+ FirstFD->getType().getTypePtr());
}
}
return nullptr;
}
-// Iterate over RecordType fields and return true if any of them matched the
-// register type
-static bool ContainsResourceForRegisterType(Sema &S, const RecordType *RT,
- RegisterType RegType) {
- llvm::SmallVector<const Type *> TypesToScan;
- TypesToScan.emplace_back(RT);
-
- while (!TypesToScan.empty()) {
- const Type *T = TypesToScan.pop_back_val();
- while (T->isArrayType())
- T = T->getArrayElementTypeNoTypeQual();
- if (T->isIntegralOrEnumerationType() || T->isFloatingType()) {
- if (RegType == RegisterType::C)
- return true;
+// Returns the handle type of a resource, or nullptr if the type of VD is not
+// a resource. The declared type is passed through without unwrapping, so a
+// VarDecl that is an array of resources yields nullptr.
+// NOTE(review): confirm callers handle arrays of resources separately.
+static const HLSLAttributedResourceType *
+FindHandleTypeOnResource(const VarDecl *VD) {
+ assert(VD != nullptr && "expected VarDecl");
+ return FindHandleTypeOnResource(VD->getType().getTypePtr());
+}
+
+// Walks through the fields of the user-defined record type RT on behalf of the
+// global variable declaration VD, collects all resource binding requirements
+// and adds them to Bindings (keyed by VD and resource class). Nested records
+// are scanned recursively.
+// Size is the multiplier applied to every resource found in RT, i.e. how many
+// RT instances this walk stands for; each resource contributes Size times the
+// product of its own field's constant array dimensions.
+// NOTE(review): the top-level Size presumably is VD's array size (1 for a
+// non-array) - confirm against the caller.
+void SemaHLSL::FindResourcesOnUserRecordDecl(const VarDecl *VD,
+ const RecordType *RT, int Size) {
+ const RecordDecl *RD = RT->getDecl();
+ for (FieldDecl *FD : RD->fields()) {
+ const Type *Ty = FD->getType()->getUnqualifiedDesugaredType();
+
+ // Calculate array size and unwrap
+ int ArraySize = 1;
+ assert(!Ty->isIncompleteArrayType() &&
+ "incomplete arrays inside user defined types are not supported");
+ while (Ty->isConstantArrayType()) {
+ const ConstantArrayType *CAT = cast<ConstantArrayType>(Ty);
+ ArraySize *= CAT->getSize().getSExtValue();
+ Ty = CAT->getElementType()->getUnqualifiedDesugaredType();
}
- const RecordType *RT = T->getAs<RecordType>();
- if (!RT)
+
+ if (!Ty->isRecordType())
continue;
- const RecordDecl *RD = RT->getDecl();
- for (FieldDecl *FD : RD->fields()) {
- const Type *FieldTy = FD->getType().getTypePtr();
- if (const HLSLAttributedResourceType *AttrResType =
- dyn_cast<HLSLAttributedResourceType>(FieldTy)) {
- ResourceClass RC = AttrResType->getAttrs().ResourceClass;
- if (getRegisterType(RC) == RegType)
- return true;
- } else {
- TypesToScan.emplace_back(FD->getType().getTypePtr());
- }
+
+ // Instances of this field across all enclosing instances (Size) and the
+ // field's own array dimensions (ArraySize). Both the new-binding and the
+ // update path, as well as the recursion, must use this same count.
+ int Count = Size * ArraySize;
+
+ // Field is a resource or array of resources
+ if (const HLSLAttributedResourceType *AttrResType =
+ FindHandleTypeOnResource(Ty)) {
+ ResourceClass RC = AttrResType->getAttrs().ResourceClass;
+
+ // Add a new DeclBindingInfo to Bindings. Update the binding size if
+ // a binding info already exists (there are multiple resources of same
+ // resource class in this user decl)
+ if (auto *DBI = Bindings.getDeclBindingInfo(VD, RC))
+ DBI->Size += Count;
+ else
+ Bindings.addDeclBindingInfo(VD, RC, Count);
+ } else if (const RecordType *FieldRT = dyn_cast<RecordType>(Ty)) {
+ // Recursively scan embedded struct or class (or array of them); each of
+ // its Count instances carries its own resources. It would be nice to do
+ // this without recursion, but it is tricky to correctly calculate the
+ // size. Hopefully nesting of structs in structs too many levels is
+ // unlikely.
+ FindResourcesOnUserRecordDecl(VD, FieldRT, Count);
}
}
- return false;
}
-static void CheckContainsResourceForRegisterType(Sema &S,
- SourceLocation &ArgLoc,
- Decl *D, RegisterType RegType,
- bool SpecifiedSpace) {
+// return false if the register binding is not valid
+static bool DiagnoseLocalRegisterBinding(Sema &S, SourceLocation &ArgLoc,
+ Decl *D, RegisterType RegType,
+ bool SpecifiedSpace) {
----------------
bogner wrote:
If we're going to add a comment here it should really say what the function does, not just what the return value means.
https://github.com/llvm/llvm-project/pull/111203
More information about the cfe-commits mailing list