[clang] [HLSL] Collect explicit resource binding information (part 1) (PR #111203)

Helena Kotas via cfe-commits cfe-commits at lists.llvm.org
Fri Oct 4 12:57:23 PDT 2024


https://github.com/hekota created https://github.com/llvm/llvm-project/pull/111203

Adds fields to `HLSLResourceBindingAttr` to store processed binding information. This information will be used by CodeGen or Sema for resource initialization or for diagnosing overlapping resource mappings.

Moves binding checks for user defined types (UDTs) to `ProcessResourceBindingOnDecl` (called from ActOnVariableDeclarator), which updates the information in the attribute and is where additional processing of the explicit resource binding will be added in the future.

Changed `handleResourceBindingAttr` to not create the resource binding attribute if the local binding diagnostic detects errors.

Part 1 of #110719

>From f545a14e11556c91d10b14617e3588fe5eae6d42 Mon Sep 17 00:00:00 2001
From: Helena Kotas <hekotas at microsoft.com>
Date: Fri, 4 Oct 2024 12:21:51 -0700
Subject: [PATCH] [HLSL] Collect explicit resource binding information (part 1)

- Do not create resource binding attribute if it is not valid
- Store basic resource binding information on HLSLResourceBindingAttr
- Move UDT type checking to ActOnVariableDeclarator

Part 1 of #110719
---
 clang/include/clang/Basic/Attr.td             |  29 +++
 clang/include/clang/Sema/SemaHLSL.h           |   2 +
 clang/lib/Sema/SemaDecl.cpp                   |   3 +
 clang/lib/Sema/SemaHLSL.cpp                   | 227 ++++++++++++------
 .../resource_binding_attr_error_udt.hlsl      |   8 +-
 5 files changed, 188 insertions(+), 81 deletions(-)

diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index fbcbf0ed416416..668c599da81390 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -4588,6 +4588,35 @@ def HLSLResourceBinding: InheritableAttr {
   let LangOpts = [HLSL];
   let Args = [StringArgument<"Slot">, StringArgument<"Space", 1>];
   let Documentation = [HLSLResourceBindingDocs];
+  let AdditionalMembers = [{
+      enum class RegisterType : unsigned { SRV, UAV, CBuffer, Sampler, C, I, Invalid };
+
+      const FieldDecl *ResourceField = nullptr;
+      RegisterType RegType;
+      unsigned SlotNumber;
+      unsigned SpaceNumber;
+
+      void setBinding(RegisterType RT, unsigned SlotNum, unsigned SpaceNum) {
+        RegType = RT;
+        SlotNumber = SlotNum;
+        SpaceNumber = SpaceNum;
+      }
+      void setResourceField(const FieldDecl *FD) {
+        ResourceField = FD;
+      }
+      const FieldDecl *getResourceField() {
+        return ResourceField;
+      }
+      RegisterType getRegisterType() {
+        return RegType;
+      }
+      unsigned getSlotNumber() {
+        return SlotNumber;
+      }
+      unsigned getSpaceNumber() {
+        return SpaceNumber;
+      }
+  }];
 }
 
 def HLSLPackOffset: HLSLAnnotationAttr {
diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h
index fa957abc9791af..018e7ea5901a2b 100644
--- a/clang/include/clang/Sema/SemaHLSL.h
+++ b/clang/include/clang/Sema/SemaHLSL.h
@@ -28,6 +28,7 @@ class AttributeCommonInfo;
 class IdentifierInfo;
 class ParsedAttr;
 class Scope;
+class VarDecl;
 
 // FIXME: This can be hidden (as static function in SemaHLSL.cpp) once we no
 // longer need to create builtin buffer types in HLSLExternalSemaSource.
@@ -62,6 +63,7 @@ class SemaHLSL : public SemaBase {
       const Attr *A, llvm::Triple::EnvironmentType Stage,
       std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages);
   void DiagnoseAvailabilityViolations(TranslationUnitDecl *TU);
+  void ProcessResourceBindingOnDecl(VarDecl *D);
 
   QualType handleVectorBinOpConversion(ExprResult &LHS, ExprResult &RHS,
                                        QualType LHSType, QualType RHSType,
diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp
index 2bf610746bc317..8e27a5e068e702 100644
--- a/clang/lib/Sema/SemaDecl.cpp
+++ b/clang/lib/Sema/SemaDecl.cpp
@@ -7876,6 +7876,9 @@ NamedDecl *Sema::ActOnVariableDeclarator(
   // Handle attributes prior to checking for duplicates in MergeVarDecl
   ProcessDeclAttributes(S, NewVD, D);
 
+  if (getLangOpts().HLSL)
+    HLSL().ProcessResourceBindingOnDecl(NewVD);
+
   // FIXME: This is probably the wrong location to be doing this and we should
   // probably be doing this for more attributes (especially for function
   // pointer attributes such as format, warn_unused_result, etc.). Ideally
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index fbcba201a351a6..568a8de30c1fc5 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -41,9 +41,7 @@
 
 using namespace clang;
 using llvm::dxil::ResourceClass;
-
-enum class RegisterType { SRV, UAV, CBuffer, Sampler, C, I, Invalid };
-
+using RegisterType = HLSLResourceBindingAttr::RegisterType;
 static RegisterType getRegisterType(ResourceClass RC) {
   switch (RC) {
   case ResourceClass::SRV:
@@ -985,44 +983,43 @@ SemaHLSL::TakeLocForHLSLAttribute(const HLSLAttributedResourceType *RT) {
   return LocInfo;
 }
 
-// get the record decl from a var decl that we expect
-// represents a resource
-static CXXRecordDecl *getRecordDeclFromVarDecl(VarDecl *VD) {
-  const Type *Ty = VD->getType()->getPointeeOrArrayElementType();
-  assert(Ty && "Resource must have an element type.");
-
-  if (Ty->isBuiltinType())
-    return nullptr;
-
-  CXXRecordDecl *TheRecordDecl = Ty->getAsCXXRecordDecl();
-  assert(TheRecordDecl && "Resource should have a resource type declaration.");
-  return TheRecordDecl;
-}
-
+// Returns handle type of a resource, if the VarDecl is a resource
+// or an array of resources
 static const HLSLAttributedResourceType *
-findAttributedResourceTypeOnField(VarDecl *VD) {
+FindHandleTypeOnResource(const VarDecl *VD) {
+  // If VarDecl is a resource class, the first field must
+  // be the resource handle of type HLSLAttributedResourceType
   assert(VD != nullptr && "expected VarDecl");
-  if (RecordDecl *RD = getRecordDeclFromVarDecl(VD)) {
-    for (auto *FD : RD->fields()) {
-      if (const HLSLAttributedResourceType *AttrResType =
-              dyn_cast<HLSLAttributedResourceType>(FD->getType().getTypePtr()))
-        return AttrResType;
+  const Type *Ty = VD->getType()->getPointeeOrArrayElementType();
+  if (RecordDecl *RD = Ty->getAsCXXRecordDecl()) {
+    if (!RD->fields().empty()) {
+      const auto &FirstFD = RD->fields().begin();
+      return dyn_cast<HLSLAttributedResourceType>(
+          FirstFD->getType().getTypePtr());
     }
   }
   return nullptr;
 }
 
-// Iterate over RecordType fields and return true if any of them matched the
-// register type
-static bool ContainsResourceForRegisterType(Sema &S, const RecordType *RT,
-                                            RegisterType RegType) {
+// Walks through the user defined record type, finds resource class
+// that matches the RegisterBinding.Type and assigns it to
+// RegisterBinding::Decl.
+static bool
+ProcessResourceBindingOnUserRecordDecl(const RecordType *RT,
+                                       HLSLResourceBindingAttr *RBA) {
+
   llvm::SmallVector<const Type *> TypesToScan;
   TypesToScan.emplace_back(RT);
+  RegisterType RegType = RBA->getRegisterType();
 
   while (!TypesToScan.empty()) {
     const Type *T = TypesToScan.pop_back_val();
-    while (T->isArrayType())
+
+    while (T->isArrayType()) {
+      // FIXME: calculate the binding size from the array dimensions (or
+      // unbounded for unsized array) size *= (size_of_array);
       T = T->getArrayElementTypeNoTypeQual();
+    }
     if (T->isIntegralOrEnumerationType() || T->isFloatingType()) {
       if (RegType == RegisterType::C)
         return true;
@@ -1037,8 +1034,12 @@ static bool ContainsResourceForRegisterType(Sema &S, const RecordType *RT,
       if (const HLSLAttributedResourceType *AttrResType =
               dyn_cast<HLSLAttributedResourceType>(FieldTy)) {
         ResourceClass RC = AttrResType->getAttrs().ResourceClass;
-        if (getRegisterType(RC) == RegType)
+        if (getRegisterType(RC) == RegType) {
+          assert(RBA->getResourceField() == nullptr &&
+                 "multiple register bindings of the same type are not allowed");
+          RBA->setResourceField(FD);
           return true;
+        }
       } else {
         TypesToScan.emplace_back(FD->getType().getTypePtr());
       }
@@ -1047,26 +1048,28 @@ static bool ContainsResourceForRegisterType(Sema &S, const RecordType *RT,
   return false;
 }
 
-static void CheckContainsResourceForRegisterType(Sema &S,
-                                                 SourceLocation &ArgLoc,
-                                                 Decl *D, RegisterType RegType,
-                                                 bool SpecifiedSpace) {
+// return false if the register binding is not valid
+static bool DiagnoseLocalRegisterBinding(Sema &S, SourceLocation &ArgLoc,
+                                         Decl *D, RegisterType RegType,
+                                         bool SpecifiedSpace) {
   int RegTypeNum = static_cast<int>(RegType);
 
   // check if the decl type is groupshared
   if (D->hasAttr<HLSLGroupSharedAddressSpaceAttr>()) {
     S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
-    return;
+    return false;
   }
 
   // Cbuffers and Tbuffers are HLSLBufferDecl types
   if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D)) {
     ResourceClass RC = CBufferOrTBuffer->isCBuffer() ? ResourceClass::CBuffer
                                                      : ResourceClass::SRV;
-    if (RegType != getRegisterType(RC))
-      S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch)
-          << RegTypeNum;
-    return;
+    if (RegType == getRegisterType(RC))
+      return true;
+
+    S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch)
+        << RegTypeNum;
+    return false;
   }
 
   // Samplers, UAVs, and SRVs are VarDecl types
@@ -1075,11 +1078,13 @@ static void CheckContainsResourceForRegisterType(Sema &S,
 
   // Resource
   if (const HLSLAttributedResourceType *AttrResType =
-          findAttributedResourceTypeOnField(VD)) {
-    if (RegType != getRegisterType(AttrResType->getAttrs().ResourceClass))
-      S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch)
-          << RegTypeNum;
-    return;
+          FindHandleTypeOnResource(VD)) {
+    if (RegType == getRegisterType(AttrResType->getAttrs().ResourceClass))
+      return true;
+
+    S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch)
+        << RegTypeNum;
+    return false;
   }
 
   const clang::Type *Ty = VD->getType().getTypePtr();
@@ -1088,36 +1093,43 @@ static void CheckContainsResourceForRegisterType(Sema &S,
 
   // Basic types
   if (Ty->isArithmeticType()) {
+    bool IsValid = true;
     bool DeclaredInCOrTBuffer = isa<HLSLBufferDecl>(D->getDeclContext());
-    if (SpecifiedSpace && !DeclaredInCOrTBuffer)
+    if (SpecifiedSpace && !DeclaredInCOrTBuffer) {
       S.Diag(ArgLoc, diag::err_hlsl_space_on_global_constant);
+      IsValid = false;
+    }
 
     if (!DeclaredInCOrTBuffer &&
         (Ty->isIntegralType(S.getASTContext()) || Ty->isFloatingType())) {
       // Default Globals
       if (RegType == RegisterType::CBuffer)
         S.Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_b);
-      else if (RegType != RegisterType::C)
+      else if (RegType != RegisterType::C) {
         S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
+        IsValid = false;
+      }
     } else {
       if (RegType == RegisterType::C)
         S.Diag(ArgLoc, diag::warn_hlsl_register_type_c_packoffset);
-      else
+      else {
         S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
+        IsValid = false;
+      }
     }
-  } else if (Ty->isRecordType()) {
-    // Class/struct types - walk the declaration and check each field and
-    // subclass
-    if (!ContainsResourceForRegisterType(S, Ty->getAs<RecordType>(), RegType))
-      S.Diag(D->getLocation(), diag::warn_hlsl_user_defined_type_missing_member)
-          << RegTypeNum;
-  } else {
-    // Anything else is an error
-    S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
+    return IsValid;
   }
+  if (Ty->isRecordType())
+    // RecordTypes will be diagnosed in ProcessResourceBindingOnDecl
+    // that is called from ActOnVariableDeclarator
+    return true;
+
+  // Anything else is an error
+  S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
+  return false;
 }
 
-static void ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl,
+static bool ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl,
                                                 RegisterType regType) {
   // make sure that there are no two register annotations
   // applied to the decl with the same register type
@@ -1135,21 +1147,19 @@ static void ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl,
 
       RegisterType otherRegType = getRegisterType(attr->getSlot());
       if (RegisterTypesDetected[static_cast<int>(otherRegType)]) {
-        if (PreviousConflicts[TheDecl].count(otherRegType))
-          continue;
         int otherRegTypeNum = static_cast<int>(otherRegType);
         S.Diag(TheDecl->getLocation(),
                diag::err_hlsl_duplicate_register_annotation)
             << otherRegTypeNum;
-        PreviousConflicts[TheDecl].insert(otherRegType);
-      } else {
-        RegisterTypesDetected[static_cast<int>(otherRegType)] = true;
+        return false;
       }
+      RegisterTypesDetected[static_cast<int>(otherRegType)] = true;
     }
   }
+  return true;
 }
 
-static void DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc,
+static bool DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc,
                                           Decl *D, RegisterType RegType,
                                           bool SpecifiedSpace) {
 
@@ -1159,10 +1169,11 @@ static void DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc,
          "expecting VarDecl or HLSLBufferDecl");
 
   // check if the declaration contains resource matching the register type
-  CheckContainsResourceForRegisterType(S, ArgLoc, D, RegType, SpecifiedSpace);
+  if (!DiagnoseLocalRegisterBinding(S, ArgLoc, D, RegType, SpecifiedSpace))
+    return false;
 
   // next, if multiple register annotations exist, check that none conflict.
-  ValidateMultipleRegisterAnnotations(S, D, RegType);
+  return ValidateMultipleRegisterAnnotations(S, D, RegType);
 }
 
 void SemaHLSL::handleResourceBindingAttr(Decl *TheDecl, const ParsedAttr &AL) {
@@ -1203,23 +1214,24 @@ void SemaHLSL::handleResourceBindingAttr(Decl *TheDecl, const ParsedAttr &AL) {
     Slot = Str;
   }
 
-  RegisterType regType;
+  RegisterType RegType;
+  unsigned SlotNum = 0;
+  unsigned SpaceNum = 0;
 
   // Validate.
   if (!Slot.empty()) {
-    regType = getRegisterType(Slot);
-    if (regType == RegisterType::I) {
+    RegType = getRegisterType(Slot);
+    if (RegType == RegisterType::I) {
       Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_i);
       return;
     }
-    if (regType == RegisterType::Invalid) {
+    if (RegType == RegisterType::Invalid) {
       Diag(ArgLoc, diag::err_hlsl_binding_type_invalid) << Slot.substr(0, 1);
       return;
     }
 
-    StringRef SlotNum = Slot.substr(1);
-    unsigned Num = 0;
-    if (SlotNum.getAsInteger(10, Num)) {
+    StringRef SlotNumStr = Slot.substr(1);
+    if (SlotNumStr.getAsInteger(10, SlotNum)) {
       Diag(ArgLoc, diag::err_hlsl_unsupported_register_number);
       return;
     }
@@ -1229,20 +1241,22 @@ void SemaHLSL::handleResourceBindingAttr(Decl *TheDecl, const ParsedAttr &AL) {
     Diag(SpaceArgLoc, diag::err_hlsl_expected_space) << Space;
     return;
   }
-  StringRef SpaceNum = Space.substr(5);
-  unsigned Num = 0;
-  if (SpaceNum.getAsInteger(10, Num)) {
+  StringRef SpaceNumStr = Space.substr(5);
+  if (SpaceNumStr.getAsInteger(10, SpaceNum)) {
     Diag(SpaceArgLoc, diag::err_hlsl_expected_space) << Space;
     return;
   }
 
-  DiagnoseHLSLRegisterAttribute(SemaRef, ArgLoc, TheDecl, regType,
-                                SpecifiedSpace);
+  if (!DiagnoseHLSLRegisterAttribute(SemaRef, ArgLoc, TheDecl, RegType,
+                                     SpecifiedSpace))
+    return;
 
   HLSLResourceBindingAttr *NewAttr =
       HLSLResourceBindingAttr::Create(getASTContext(), Slot, Space, AL);
-  if (NewAttr)
+  if (NewAttr) {
+    NewAttr->setBinding(RegType, SlotNum, SpaceNum);
     TheDecl->addAttr(NewAttr);
+  }
 }
 
 void SemaHLSL::handleParamModifierAttr(Decl *D, const ParsedAttr &AL) {
@@ -2228,3 +2242,62 @@ QualType SemaHLSL::getInoutParameterType(QualType Ty) {
   Ty.addRestrict();
   return Ty;
 }
+
+// Walks through existing explicit bindings, finds the actual resource class
+// decl the binding applies to and sets it to attr->ResourceField.
+// Additional processing of resource binding can be added here later on,
+// such as preparation for overlapping resource detection or implicit binding.
+void SemaHLSL::ProcessResourceBindingOnDecl(VarDecl *D) {
+  if (!D->hasGlobalStorage())
+    return;
+  
+  for (Attr *A : D->attrs()) {
+    HLSLResourceBindingAttr *RBA = dyn_cast<HLSLResourceBindingAttr>(A);
+    if (!RBA)
+      continue;
+
+    // // Cbuffers and Tbuffers are HLSLBufferDecl types
+    if (const HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D)) {
+      assert(RBA->getRegisterType() ==
+                 getRegisterType(CBufferOrTBuffer->isCBuffer()
+                                     ? ResourceClass::CBuffer
+                                     : ResourceClass::SRV) &&
+             "this should have been handled in DiagnoseLocalRegisterBinding");
+      // should we handle HLSLBufferDecl here?
+      continue;
+    }
+
+    // Samplers, UAVs, and SRVs are VarDecl types
+    assert(isa<VarDecl>(D) && "D is expected to be VarDecl or HLSLBufferDecl");
+    const VarDecl *VD = cast<VarDecl>(D);
+
+    // Register binding directly on global resource class variable
+    if (const HLSLAttributedResourceType *AttrResType =
+            FindHandleTypeOnResource(VD)) {
+      // FIXME: if array, calculate the binding size from the array dimensions
+      // (or unbounded for unsized array)
+      assert(RBA->getResourceField() == nullptr);
+      continue;
+    }
+
+    // Global array
+    const clang::Type *Ty = VD->getType().getTypePtr();
+    while (Ty->isArrayType()) {
+      Ty = Ty->getArrayElementTypeNoTypeQual();
+    }
+
+    // Basic types
+    if (Ty->isArithmeticType()) {
+      continue;
+    }
+
+    if (Ty->isRecordType()) {
+      if (!ProcessResourceBindingOnUserRecordDecl(Ty->getAs<RecordType>(),
+                                                  RBA)) {
+        SemaRef.Diag(D->getLocation(),
+                     diag::warn_hlsl_user_defined_type_missing_member)
+            << static_cast<int>(RBA->getRegisterType());
+      }
+    }
+  }
+}
diff --git a/clang/test/SemaHLSL/resource_binding_attr_error_udt.hlsl b/clang/test/SemaHLSL/resource_binding_attr_error_udt.hlsl
index ea2d576e4cca55..40517f393e1284 100644
--- a/clang/test/SemaHLSL/resource_binding_attr_error_udt.hlsl
+++ b/clang/test/SemaHLSL/resource_binding_attr_error_udt.hlsl
@@ -106,7 +106,6 @@ struct Eg12{
   MySRV s1;
   MySRV s2;
 };
-// expected-warning@+3{{binding type 'u' only applies to types containing UAV resources}}
 // expected-warning@+2{{binding type 'u' only applies to types containing UAV resources}}
 // expected-error@+1{{binding type 'u' cannot be applied more than once}}
 Eg12 e12 : register(u9) : register(u10);
@@ -115,12 +114,14 @@ struct Eg13{
   MySRV s1;
   MySRV s2;
 };
-// expected-warning@+4{{binding type 'u' only applies to types containing UAV resources}}
 // expected-warning@+3{{binding type 'u' only applies to types containing UAV resources}}
-// expected-warning@+2{{binding type 'u' only applies to types containing UAV resources}}
+// expected-error@+2{{binding type 'u' cannot be applied more than once}}
 // expected-error@+1{{binding type 'u' cannot be applied more than once}}
 Eg13 e13 : register(u9) : register(u10) : register(u11);
 
+// expected-error@+1{{binding type 't' cannot be applied more than once}}
+Eg13 e13_2 : register(t11) : register(t12);
+
 struct Eg14{
  MyTemplatedUAV<int> r1;  
 };
@@ -132,4 +133,3 @@ struct Eg15 {
 }; 
 // expected no error
 Eg15 e15 : register(c0);
-



More information about the cfe-commits mailing list