[clang] [HLSL] Collect explicit resource binding information (PR #111203)

Justin Bogner via cfe-commits cfe-commits at lists.llvm.org
Tue Oct 15 08:33:55 PDT 2024


================
@@ -985,88 +1026,92 @@ SemaHLSL::TakeLocForHLSLAttribute(const HLSLAttributedResourceType *RT) {
   return LocInfo;
 }
 
-// get the record decl from a var decl that we expect
-// represents a resource
-static CXXRecordDecl *getRecordDeclFromVarDecl(VarDecl *VD) {
-  const Type *Ty = VD->getType()->getPointeeOrArrayElementType();
-  assert(Ty && "Resource must have an element type.");
-
-  if (Ty->isBuiltinType())
-    return nullptr;
-
-  CXXRecordDecl *TheRecordDecl = Ty->getAsCXXRecordDecl();
-  assert(TheRecordDecl && "Resource should have a resource type declaration.");
-  return TheRecordDecl;
-}
-
+// Returns handle type of a resource, if the type is a resource
 static const HLSLAttributedResourceType *
-findAttributedResourceTypeOnField(VarDecl *VD) {
-  assert(VD != nullptr && "expected VarDecl");
-  if (RecordDecl *RD = getRecordDeclFromVarDecl(VD)) {
-    for (auto *FD : RD->fields()) {
-      if (const HLSLAttributedResourceType *AttrResType =
-              dyn_cast<HLSLAttributedResourceType>(FD->getType().getTypePtr()))
-        return AttrResType;
+FindHandleTypeOnResource(const Type *Ty) {
+  // If Ty is a resource class, the first field must
+  // be the resource handle of type HLSLAttributedResourceType
+  if (RecordDecl *RD = Ty->getAsCXXRecordDecl()) {
+    if (!RD->fields().empty()) {
+      const auto &FirstFD = RD->fields().begin();
+      return dyn_cast<HLSLAttributedResourceType>(
+          FirstFD->getType().getTypePtr());
     }
   }
   return nullptr;
 }
 
-// Iterate over RecordType fields and return true if any of them matched the
-// register type
-static bool ContainsResourceForRegisterType(Sema &S, const RecordType *RT,
-                                            RegisterType RegType) {
-  llvm::SmallVector<const Type *> TypesToScan;
-  TypesToScan.emplace_back(RT);
-
-  while (!TypesToScan.empty()) {
-    const Type *T = TypesToScan.pop_back_val();
-    while (T->isArrayType())
-      T = T->getArrayElementTypeNoTypeQual();
-    if (T->isIntegralOrEnumerationType() || T->isFloatingType()) {
-      if (RegType == RegisterType::C)
-        return true;
+// Returns handle type of a resource, if the VarDecl is a resource
+static const HLSLAttributedResourceType *
+FindHandleTypeOnResource(const VarDecl *VD) {
+  assert(VD != nullptr && "expected VarDecl");
+  return FindHandleTypeOnResource(VD->getType().getTypePtr());
+}
+
+// Walks through the global variable declaration, collects all resource
+// binding requirements and adds them to Bindings
+void SemaHLSL::FindResourcesOnUserRecordDecl(const VarDecl *VD,
+                                             const RecordType *RT, int Size) {
+  const RecordDecl *RD = RT->getDecl();
+  for (FieldDecl *FD : RD->fields()) {
+    const Type *Ty = FD->getType()->getUnqualifiedDesugaredType();
+
+    // Calculate array size and unwrap
+    int ArraySize = 1;
+    assert(!Ty->isIncompleteArrayType() &&
+           "incomplete arrays inside user defined types are not supported");
+    while (Ty->isConstantArrayType()) {
+      const ConstantArrayType *CAT = cast<ConstantArrayType>(Ty);
+      ArraySize *= CAT->getSize().getSExtValue();
+      Ty = CAT->getElementType()->getUnqualifiedDesugaredType();
     }
-    const RecordType *RT = T->getAs<RecordType>();
-    if (!RT)
+
+    if (!Ty->isRecordType())
       continue;
 
-    const RecordDecl *RD = RT->getDecl();
-    for (FieldDecl *FD : RD->fields()) {
-      const Type *FieldTy = FD->getType().getTypePtr();
-      if (const HLSLAttributedResourceType *AttrResType =
-              dyn_cast<HLSLAttributedResourceType>(FieldTy)) {
-        ResourceClass RC = AttrResType->getAttrs().ResourceClass;
-        if (getRegisterType(RC) == RegType)
-          return true;
-      } else {
-        TypesToScan.emplace_back(FD->getType().getTypePtr());
-      }
+    // Field is a resource or array of resources
+    if (const HLSLAttributedResourceType *AttrResType =
+            FindHandleTypeOnResource(Ty)) {
+      ResourceClass RC = AttrResType->getAttrs().ResourceClass;
+
+      // Add a new DeclBindingInfo to Bindings. Update the binding size if
+      // a binding info already exists (there are multiple resources of same
+      // resource class in this user decl)
+      if (auto *DBI = Bindings.getDeclBindingInfo(VD, RC))
+        DBI->Size += Size * ArraySize;
+      else
+        Bindings.addDeclBindingInfo(VD, RC, Size);
+    } else if (const RecordType *RT = dyn_cast<RecordType>(Ty)) {
+      // Recursively scan embedded struct or class; it would be nice to do this
+      // without recursion, but tricky to correctly calculate the size.
+      // Hopefully nesting of structs in structs too many levels is unlikely.
+      FindResourcesOnUserRecordDecl(VD, RT, Size);
     }
   }
-  return false;
 }
 
-static void CheckContainsResourceForRegisterType(Sema &S,
-                                                 SourceLocation &ArgLoc,
-                                                 Decl *D, RegisterType RegType,
-                                                 bool SpecifiedSpace) {
+// return false if the register binding is not valid
+static bool DiagnoseLocalRegisterBinding(Sema &S, SourceLocation &ArgLoc,
+                                         Decl *D, RegisterType RegType,
+                                         bool SpecifiedSpace) {
----------------
bogner wrote:

If we're going to add a comment here, it should really say what the function does, not just what the return value means.

https://github.com/llvm/llvm-project/pull/111203


More information about the cfe-commits mailing list