[clang] [HLSL] Collect explicit resource binding information (PR #111203)

Helena Kotas via cfe-commits cfe-commits at lists.llvm.org
Wed Oct 16 10:15:21 PDT 2024


https://github.com/hekota updated https://github.com/llvm/llvm-project/pull/111203

>From f545a14e11556c91d10b14617e3588fe5eae6d42 Mon Sep 17 00:00:00 2001
From: Helena Kotas <hekotas at microsoft.com>
Date: Fri, 4 Oct 2024 12:21:51 -0700
Subject: [PATCH 1/6] [HLSL] Collect explicit resource binding information
 (part 1)

- Do not create resource binding attribute if it is not valid
- Store basic resource binding information on HLSLResourceBindingAttr
- Move UDT type checking to ActOnVariableDeclarator

Part 1 of #110719
---
 clang/include/clang/Basic/Attr.td             |  29 +++
 clang/include/clang/Sema/SemaHLSL.h           |   2 +
 clang/lib/Sema/SemaDecl.cpp                   |   3 +
 clang/lib/Sema/SemaHLSL.cpp                   | 227 ++++++++++++------
 .../resource_binding_attr_error_udt.hlsl      |   8 +-
 5 files changed, 188 insertions(+), 81 deletions(-)

diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index fbcbf0ed416416..668c599da81390 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -4588,6 +4588,35 @@ def HLSLResourceBinding: InheritableAttr {
   let LangOpts = [HLSL];
   let Args = [StringArgument<"Slot">, StringArgument<"Space", 1>];
   let Documentation = [HLSLResourceBindingDocs];
+  let AdditionalMembers = [{
+      enum class RegisterType : unsigned { SRV, UAV, CBuffer, Sampler, C, I, Invalid };
+
+      const FieldDecl *ResourceField = nullptr;
+      RegisterType RegType;
+      unsigned SlotNumber;
+      unsigned SpaceNumber;
+
+      void setBinding(RegisterType RT, unsigned SlotNum, unsigned SpaceNum) {
+        RegType = RT;
+        SlotNumber = SlotNum;
+        SpaceNumber = SpaceNum;
+      }
+      void setResourceField(const FieldDecl *FD) {
+        ResourceField = FD;
+      }
+      const FieldDecl *getResourceField() {
+        return ResourceField;
+      }
+      RegisterType getRegisterType() {
+        return RegType;
+      }
+      unsigned getSlotNumber() {
+        return SlotNumber;
+      }
+      unsigned getSpaceNumber() {
+        return SpaceNumber;
+      }
+  }];
 }
 
 def HLSLPackOffset: HLSLAnnotationAttr {
diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h
index fa957abc9791af..018e7ea5901a2b 100644
--- a/clang/include/clang/Sema/SemaHLSL.h
+++ b/clang/include/clang/Sema/SemaHLSL.h
@@ -28,6 +28,7 @@ class AttributeCommonInfo;
 class IdentifierInfo;
 class ParsedAttr;
 class Scope;
+class VarDecl;
 
 // FIXME: This can be hidden (as static function in SemaHLSL.cpp) once we no
 // longer need to create builtin buffer types in HLSLExternalSemaSource.
@@ -62,6 +63,7 @@ class SemaHLSL : public SemaBase {
       const Attr *A, llvm::Triple::EnvironmentType Stage,
       std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages);
   void DiagnoseAvailabilityViolations(TranslationUnitDecl *TU);
+  void ProcessResourceBindingOnDecl(VarDecl *D);
 
   QualType handleVectorBinOpConversion(ExprResult &LHS, ExprResult &RHS,
                                        QualType LHSType, QualType RHSType,
diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp
index 2bf610746bc317..8e27a5e068e702 100644
--- a/clang/lib/Sema/SemaDecl.cpp
+++ b/clang/lib/Sema/SemaDecl.cpp
@@ -7876,6 +7876,9 @@ NamedDecl *Sema::ActOnVariableDeclarator(
   // Handle attributes prior to checking for duplicates in MergeVarDecl
   ProcessDeclAttributes(S, NewVD, D);
 
+  if (getLangOpts().HLSL)
+    HLSL().ProcessResourceBindingOnDecl(NewVD);
+
   // FIXME: This is probably the wrong location to be doing this and we should
   // probably be doing this for more attributes (especially for function
   // pointer attributes such as format, warn_unused_result, etc.). Ideally
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index fbcba201a351a6..568a8de30c1fc5 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -41,9 +41,7 @@
 
 using namespace clang;
 using llvm::dxil::ResourceClass;
-
-enum class RegisterType { SRV, UAV, CBuffer, Sampler, C, I, Invalid };
-
+using RegisterType = HLSLResourceBindingAttr::RegisterType;
 static RegisterType getRegisterType(ResourceClass RC) {
   switch (RC) {
   case ResourceClass::SRV:
@@ -985,44 +983,43 @@ SemaHLSL::TakeLocForHLSLAttribute(const HLSLAttributedResourceType *RT) {
   return LocInfo;
 }
 
-// get the record decl from a var decl that we expect
-// represents a resource
-static CXXRecordDecl *getRecordDeclFromVarDecl(VarDecl *VD) {
-  const Type *Ty = VD->getType()->getPointeeOrArrayElementType();
-  assert(Ty && "Resource must have an element type.");
-
-  if (Ty->isBuiltinType())
-    return nullptr;
-
-  CXXRecordDecl *TheRecordDecl = Ty->getAsCXXRecordDecl();
-  assert(TheRecordDecl && "Resource should have a resource type declaration.");
-  return TheRecordDecl;
-}
-
+// Returns handle type of a resource, if the VarDecl is a resource
+// or an array of resources
 static const HLSLAttributedResourceType *
-findAttributedResourceTypeOnField(VarDecl *VD) {
+FindHandleTypeOnResource(const VarDecl *VD) {
+  // If VarDecl is a resource class, the first field must
+  // be the resource handle of type HLSLAttributedResourceType
   assert(VD != nullptr && "expected VarDecl");
-  if (RecordDecl *RD = getRecordDeclFromVarDecl(VD)) {
-    for (auto *FD : RD->fields()) {
-      if (const HLSLAttributedResourceType *AttrResType =
-              dyn_cast<HLSLAttributedResourceType>(FD->getType().getTypePtr()))
-        return AttrResType;
+  const Type *Ty = VD->getType()->getPointeeOrArrayElementType();
+  if (RecordDecl *RD = Ty->getAsCXXRecordDecl()) {
+    if (!RD->fields().empty()) {
+      const auto &FirstFD = RD->fields().begin();
+      return dyn_cast<HLSLAttributedResourceType>(
+          FirstFD->getType().getTypePtr());
     }
   }
   return nullptr;
 }
 
-// Iterate over RecordType fields and return true if any of them matched the
-// register type
-static bool ContainsResourceForRegisterType(Sema &S, const RecordType *RT,
-                                            RegisterType RegType) {
+// Walks though the user defined record type, finds resource class
+// that matches the RegisterBinding.Type and assigns it to
+// RegisterBinding::Decl.
+static bool
+ProcessResourceBindingOnUserRecordDecl(const RecordType *RT,
+                                       HLSLResourceBindingAttr *RBA) {
+
   llvm::SmallVector<const Type *> TypesToScan;
   TypesToScan.emplace_back(RT);
+  RegisterType RegType = RBA->getRegisterType();
 
   while (!TypesToScan.empty()) {
     const Type *T = TypesToScan.pop_back_val();
-    while (T->isArrayType())
+
+    while (T->isArrayType()) {
+      // FIXME: calculate the binding size from the array dimensions (or
+      // unbounded for unsized array) size *= (size_of_array);
       T = T->getArrayElementTypeNoTypeQual();
+    }
     if (T->isIntegralOrEnumerationType() || T->isFloatingType()) {
       if (RegType == RegisterType::C)
         return true;
@@ -1037,8 +1034,12 @@ static bool ContainsResourceForRegisterType(Sema &S, const RecordType *RT,
       if (const HLSLAttributedResourceType *AttrResType =
               dyn_cast<HLSLAttributedResourceType>(FieldTy)) {
         ResourceClass RC = AttrResType->getAttrs().ResourceClass;
-        if (getRegisterType(RC) == RegType)
+        if (getRegisterType(RC) == RegType) {
+          assert(RBA->getResourceField() == nullptr &&
+                 "multiple register bindings of the same type are not allowed");
+          RBA->setResourceField(FD);
           return true;
+        }
       } else {
         TypesToScan.emplace_back(FD->getType().getTypePtr());
       }
@@ -1047,26 +1048,28 @@ static bool ContainsResourceForRegisterType(Sema &S, const RecordType *RT,
   return false;
 }
 
-static void CheckContainsResourceForRegisterType(Sema &S,
-                                                 SourceLocation &ArgLoc,
-                                                 Decl *D, RegisterType RegType,
-                                                 bool SpecifiedSpace) {
+// return false if the register binding is not valid
+static bool DiagnoseLocalRegisterBinding(Sema &S, SourceLocation &ArgLoc,
+                                         Decl *D, RegisterType RegType,
+                                         bool SpecifiedSpace) {
   int RegTypeNum = static_cast<int>(RegType);
 
   // check if the decl type is groupshared
   if (D->hasAttr<HLSLGroupSharedAddressSpaceAttr>()) {
     S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
-    return;
+    return false;
   }
 
   // Cbuffers and Tbuffers are HLSLBufferDecl types
   if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D)) {
     ResourceClass RC = CBufferOrTBuffer->isCBuffer() ? ResourceClass::CBuffer
                                                      : ResourceClass::SRV;
-    if (RegType != getRegisterType(RC))
-      S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch)
-          << RegTypeNum;
-    return;
+    if (RegType == getRegisterType(RC))
+      return true;
+
+    S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch)
+        << RegTypeNum;
+    return false;
   }
 
   // Samplers, UAVs, and SRVs are VarDecl types
@@ -1075,11 +1078,13 @@ static void CheckContainsResourceForRegisterType(Sema &S,
 
   // Resource
   if (const HLSLAttributedResourceType *AttrResType =
-          findAttributedResourceTypeOnField(VD)) {
-    if (RegType != getRegisterType(AttrResType->getAttrs().ResourceClass))
-      S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch)
-          << RegTypeNum;
-    return;
+          FindHandleTypeOnResource(VD)) {
+    if (RegType == getRegisterType(AttrResType->getAttrs().ResourceClass))
+      return true;
+
+    S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch)
+        << RegTypeNum;
+    return false;
   }
 
   const clang::Type *Ty = VD->getType().getTypePtr();
@@ -1088,36 +1093,43 @@ static void CheckContainsResourceForRegisterType(Sema &S,
 
   // Basic types
   if (Ty->isArithmeticType()) {
+    bool IsValid = true;
     bool DeclaredInCOrTBuffer = isa<HLSLBufferDecl>(D->getDeclContext());
-    if (SpecifiedSpace && !DeclaredInCOrTBuffer)
+    if (SpecifiedSpace && !DeclaredInCOrTBuffer) {
       S.Diag(ArgLoc, diag::err_hlsl_space_on_global_constant);
+      IsValid = false;
+    }
 
     if (!DeclaredInCOrTBuffer &&
         (Ty->isIntegralType(S.getASTContext()) || Ty->isFloatingType())) {
       // Default Globals
       if (RegType == RegisterType::CBuffer)
         S.Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_b);
-      else if (RegType != RegisterType::C)
+      else if (RegType != RegisterType::C) {
         S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
+        IsValid = false;
+      }
     } else {
       if (RegType == RegisterType::C)
         S.Diag(ArgLoc, diag::warn_hlsl_register_type_c_packoffset);
-      else
+      else {
         S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
+        IsValid = false;
+      }
     }
-  } else if (Ty->isRecordType()) {
-    // Class/struct types - walk the declaration and check each field and
-    // subclass
-    if (!ContainsResourceForRegisterType(S, Ty->getAs<RecordType>(), RegType))
-      S.Diag(D->getLocation(), diag::warn_hlsl_user_defined_type_missing_member)
-          << RegTypeNum;
-  } else {
-    // Anything else is an error
-    S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
+    return IsValid;
   }
+  if (Ty->isRecordType())
+    // RecordTypes will be diagnosed in ProcessResourceBindingOnDecl
+    // that is called from ActOnVariableDeclarator
+    return true;
+
+  // Anything else is an error
+  S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
+  return false;
 }
 
-static void ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl,
+static bool ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl,
                                                 RegisterType regType) {
   // make sure that there are no two register annotations
   // applied to the decl with the same register type
@@ -1135,21 +1147,19 @@ static void ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl,
 
       RegisterType otherRegType = getRegisterType(attr->getSlot());
       if (RegisterTypesDetected[static_cast<int>(otherRegType)]) {
-        if (PreviousConflicts[TheDecl].count(otherRegType))
-          continue;
         int otherRegTypeNum = static_cast<int>(otherRegType);
         S.Diag(TheDecl->getLocation(),
                diag::err_hlsl_duplicate_register_annotation)
             << otherRegTypeNum;
-        PreviousConflicts[TheDecl].insert(otherRegType);
-      } else {
-        RegisterTypesDetected[static_cast<int>(otherRegType)] = true;
+        return false;
       }
+      RegisterTypesDetected[static_cast<int>(otherRegType)] = true;
     }
   }
+  return true;
 }
 
-static void DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc,
+static bool DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc,
                                           Decl *D, RegisterType RegType,
                                           bool SpecifiedSpace) {
 
@@ -1159,10 +1169,11 @@ static void DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc,
          "expecting VarDecl or HLSLBufferDecl");
 
   // check if the declaration contains resource matching the register type
-  CheckContainsResourceForRegisterType(S, ArgLoc, D, RegType, SpecifiedSpace);
+  if (!DiagnoseLocalRegisterBinding(S, ArgLoc, D, RegType, SpecifiedSpace))
+    return false;
 
   // next, if multiple register annotations exist, check that none conflict.
-  ValidateMultipleRegisterAnnotations(S, D, RegType);
+  return ValidateMultipleRegisterAnnotations(S, D, RegType);
 }
 
 void SemaHLSL::handleResourceBindingAttr(Decl *TheDecl, const ParsedAttr &AL) {
@@ -1203,23 +1214,24 @@ void SemaHLSL::handleResourceBindingAttr(Decl *TheDecl, const ParsedAttr &AL) {
     Slot = Str;
   }
 
-  RegisterType regType;
+  RegisterType RegType;
+  unsigned SlotNum = 0;
+  unsigned SpaceNum = 0;
 
   // Validate.
   if (!Slot.empty()) {
-    regType = getRegisterType(Slot);
-    if (regType == RegisterType::I) {
+    RegType = getRegisterType(Slot);
+    if (RegType == RegisterType::I) {
       Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_i);
       return;
     }
-    if (regType == RegisterType::Invalid) {
+    if (RegType == RegisterType::Invalid) {
       Diag(ArgLoc, diag::err_hlsl_binding_type_invalid) << Slot.substr(0, 1);
       return;
     }
 
-    StringRef SlotNum = Slot.substr(1);
-    unsigned Num = 0;
-    if (SlotNum.getAsInteger(10, Num)) {
+    StringRef SlotNumStr = Slot.substr(1);
+    if (SlotNumStr.getAsInteger(10, SlotNum)) {
       Diag(ArgLoc, diag::err_hlsl_unsupported_register_number);
       return;
     }
@@ -1229,20 +1241,22 @@ void SemaHLSL::handleResourceBindingAttr(Decl *TheDecl, const ParsedAttr &AL) {
     Diag(SpaceArgLoc, diag::err_hlsl_expected_space) << Space;
     return;
   }
-  StringRef SpaceNum = Space.substr(5);
-  unsigned Num = 0;
-  if (SpaceNum.getAsInteger(10, Num)) {
+  StringRef SpaceNumStr = Space.substr(5);
+  if (SpaceNumStr.getAsInteger(10, SpaceNum)) {
     Diag(SpaceArgLoc, diag::err_hlsl_expected_space) << Space;
     return;
   }
 
-  DiagnoseHLSLRegisterAttribute(SemaRef, ArgLoc, TheDecl, regType,
-                                SpecifiedSpace);
+  if (!DiagnoseHLSLRegisterAttribute(SemaRef, ArgLoc, TheDecl, RegType,
+                                     SpecifiedSpace))
+    return;
 
   HLSLResourceBindingAttr *NewAttr =
       HLSLResourceBindingAttr::Create(getASTContext(), Slot, Space, AL);
-  if (NewAttr)
+  if (NewAttr) {
+    NewAttr->setBinding(RegType, SlotNum, SpaceNum);
     TheDecl->addAttr(NewAttr);
+  }
 }
 
 void SemaHLSL::handleParamModifierAttr(Decl *D, const ParsedAttr &AL) {
@@ -2228,3 +2242,62 @@ QualType SemaHLSL::getInoutParameterType(QualType Ty) {
   Ty.addRestrict();
   return Ty;
 }
+
+// Walks though existing explicit bindings, finds the actual resource class
+// decl the binding applies to and sets it to attr->ResourceField.
+// Additional processing of resource binding can be added here later on,
+// such as preparation for overapping resource detection or implicit binding.
+void SemaHLSL::ProcessResourceBindingOnDecl(VarDecl *D) {
+  if (!D->hasGlobalStorage())
+    return;
+  
+  for (Attr *A : D->attrs()) {
+    HLSLResourceBindingAttr *RBA = dyn_cast<HLSLResourceBindingAttr>(A);
+    if (!RBA)
+      continue;
+
+    // // Cbuffers and Tbuffers are HLSLBufferDecl types
+    if (const HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D)) {
+      assert(RBA->getRegisterType() ==
+                 getRegisterType(CBufferOrTBuffer->isCBuffer()
+                                     ? ResourceClass::CBuffer
+                                     : ResourceClass::SRV) &&
+             "this should have been handled in DiagnoseLocalRegisterBinding");
+      // should we handle HLSLBufferDecl here?
+      continue;
+    }
+
+    // Samplers, UAVs, and SRVs are VarDecl types
+    assert(isa<VarDecl>(D) && "D is expected to be VarDecl or HLSLBufferDecl");
+    const VarDecl *VD = cast<VarDecl>(D);
+
+    // Register binding directly on global resource class variable
+    if (const HLSLAttributedResourceType *AttrResType =
+            FindHandleTypeOnResource(VD)) {
+      // FIXME: if array, calculate the binding size from the array dimensions
+      // (or unbounded for unsized array)
+      assert(RBA->getResourceField() == nullptr);
+      continue;
+    }
+
+    // Global array
+    const clang::Type *Ty = VD->getType().getTypePtr();
+    while (Ty->isArrayType()) {
+      Ty = Ty->getArrayElementTypeNoTypeQual();
+    }
+
+    // Basic types
+    if (Ty->isArithmeticType()) {
+      continue;
+    }
+
+    if (Ty->isRecordType()) {
+      if (!ProcessResourceBindingOnUserRecordDecl(Ty->getAs<RecordType>(),
+                                                  RBA)) {
+        SemaRef.Diag(D->getLocation(),
+                     diag::warn_hlsl_user_defined_type_missing_member)
+            << static_cast<int>(RBA->getRegisterType());
+      }
+    }
+  }
+}
diff --git a/clang/test/SemaHLSL/resource_binding_attr_error_udt.hlsl b/clang/test/SemaHLSL/resource_binding_attr_error_udt.hlsl
index ea2d576e4cca55..40517f393e1284 100644
--- a/clang/test/SemaHLSL/resource_binding_attr_error_udt.hlsl
+++ b/clang/test/SemaHLSL/resource_binding_attr_error_udt.hlsl
@@ -106,7 +106,6 @@ struct Eg12{
   MySRV s1;
   MySRV s2;
 };
-// expected-warning at +3{{binding type 'u' only applies to types containing UAV resources}}
 // expected-warning at +2{{binding type 'u' only applies to types containing UAV resources}}
 // expected-error at +1{{binding type 'u' cannot be applied more than once}}
 Eg12 e12 : register(u9) : register(u10);
@@ -115,12 +114,14 @@ struct Eg13{
   MySRV s1;
   MySRV s2;
 };
-// expected-warning at +4{{binding type 'u' only applies to types containing UAV resources}}
 // expected-warning at +3{{binding type 'u' only applies to types containing UAV resources}}
-// expected-warning at +2{{binding type 'u' only applies to types containing UAV resources}}
+// expected-error at +2{{binding type 'u' cannot be applied more than once}}
 // expected-error at +1{{binding type 'u' cannot be applied more than once}}
 Eg13 e13 : register(u9) : register(u10) : register(u11);
 
+// expected-error at +1{{binding type 't' cannot be applied more than once}}
+Eg13 e13_2 : register(t11) : register(t12);
+
 struct Eg14{
  MyTemplatedUAV<int> r1;  
 };
@@ -132,4 +133,3 @@ struct Eg15 {
 }; 
 // expected no error
 Eg15 e15 : register(c0);
-

>From a6c06943ce5df79e6765e12874c96c907b20d030 Mon Sep 17 00:00:00 2001
From: Helena Kotas <hekotas at microsoft.com>
Date: Fri, 4 Oct 2024 13:52:47 -0700
Subject: [PATCH 2/6] clang-format

---
 clang/lib/Sema/SemaHLSL.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 568a8de30c1fc5..5c27a74a853bba 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -2250,7 +2250,7 @@ QualType SemaHLSL::getInoutParameterType(QualType Ty) {
 void SemaHLSL::ProcessResourceBindingOnDecl(VarDecl *D) {
   if (!D->hasGlobalStorage())
     return;
-  
+
   for (Attr *A : D->attrs()) {
     HLSLResourceBindingAttr *RBA = dyn_cast<HLSLResourceBindingAttr>(A);
     if (!RBA)

>From a6a52327bef4325a00a2b8a1715b8b5b1315994f Mon Sep 17 00:00:00 2001
From: Helena Kotas <hekotas at microsoft.com>
Date: Wed, 9 Oct 2024 16:34:06 -0700
Subject: [PATCH 3/6] Collect all resource binding requirements and analyze
 explicit bindings based on that. Also adds binding size calculation and
 removes the ResourceField member from HLSLResourceBindingAttr.

---
 clang/include/clang/Basic/Attr.td   |  25 ++-
 clang/include/clang/Sema/SemaHLSL.h |  59 +++++-
 clang/lib/Sema/SemaDecl.cpp         |   2 +-
 clang/lib/Sema/SemaHLSL.cpp         | 276 ++++++++++++++++++----------
 4 files changed, 256 insertions(+), 106 deletions(-)

diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index 668c599da81390..3997ffe78fbf96 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -4591,22 +4591,20 @@ def HLSLResourceBinding: InheritableAttr {
   let AdditionalMembers = [{
       enum class RegisterType : unsigned { SRV, UAV, CBuffer, Sampler, C, I, Invalid };
 
-      const FieldDecl *ResourceField = nullptr;
       RegisterType RegType;
       unsigned SlotNumber;
       unsigned SpaceNumber;
+      
+      // Size of the binding
+      // 0 == not set
+      //-1 == unbounded
+      int Size;
 
-      void setBinding(RegisterType RT, unsigned SlotNum, unsigned SpaceNum) {
+      void setBinding(RegisterType RT, unsigned SlotNum, unsigned SpaceNum, int Size = 0) {
         RegType = RT;
         SlotNumber = SlotNum;
         SpaceNumber = SpaceNum;
       }
-      void setResourceField(const FieldDecl *FD) {
-        ResourceField = FD;
-      }
-      const FieldDecl *getResourceField() {
-        return ResourceField;
-      }
       RegisterType getRegisterType() {
         return RegType;
       }
@@ -4616,6 +4614,17 @@ def HLSLResourceBinding: InheritableAttr {
       unsigned getSpaceNumber() {
         return SpaceNumber;
       }
+      unsigned getSize() {
+        assert(Size == -1 || Size > 0 && "size not set");
+        return Size;
+      }
+      void setSize(int N) {
+        assert(N == -1 || N > 0 && "unexpected size value");
+        Size = N;
+      }
+      bool isSizeUnbounded() {
+        return Size == -1;
+      }
   }];
 }
 
diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h
index 018e7ea5901a2b..ce262fd41dff37 100644
--- a/clang/include/clang/Sema/SemaHLSL.h
+++ b/clang/include/clang/Sema/SemaHLSL.h
@@ -30,12 +30,60 @@ class ParsedAttr;
 class Scope;
 class VarDecl;
 
+using llvm::dxil::ResourceClass;
+
 // FIXME: This can be hidden (as static function in SemaHLSL.cpp) once we no
 // longer need to create builtin buffer types in HLSLExternalSemaSource.
 bool CreateHLSLAttributedResourceType(
     Sema &S, QualType Wrapped, ArrayRef<const Attr *> AttrList,
     QualType &ResType, HLSLAttributedResourceLocInfo *LocInfo = nullptr);
 
+enum class BindingType : uint8_t { NotAssigned, Explicit, Implicit };
+
+// DeclBindingInfo struct stores information about required/assigned resource
+// binding onon a declaration for specific resource class.
+struct DeclBindingInfo {
+  const VarDecl *Decl;
+  ResourceClass ResClass;
+  int Size; // -1 == unbounded array
+  const HLSLResourceBindingAttr *Attr;
+  BindingType BindType;
+
+  DeclBindingInfo(const VarDecl *Decl, ResourceClass ResClass, int Size = 0,
+                  BindingType BindType = BindingType::NotAssigned,
+                  const HLSLResourceBindingAttr *Attr = nullptr)
+      : Decl(Decl), ResClass(ResClass), Size(Size), Attr(Attr),
+        BindType(BindType) {}
+
+  void setBindingAttribute(HLSLResourceBindingAttr *A, BindingType BT) {
+    assert(Attr == nullptr && BindType == BindingType::NotAssigned &&
+           "binding attribute already assigned");
+    Attr = A;
+    BindType = BT;
+  }
+};
+
+// ResourceBindings class stores information about all resource bindings
+// in a shader. It is used for binding diagnostics and implicit binding
+// assigments.
+class ResourceBindings {
+public:
+  DeclBindingInfo *addDeclBindingInfo(const VarDecl *VD, ResourceClass ResClass,
+                                      int Size);
+  DeclBindingInfo *getDeclBindingInfo(const VarDecl *VD,
+                                      ResourceClass ResClass);
+  bool hasBindingInfoForDecl(const VarDecl *VD);
+
+private:
+  // List of all resource bindings required by the shader.
+  // A global declaration can have multiple bindings for different
+  // resource classes. They are all stored sequentially in this list.
+  // The DeclToBindingListIndex hashtable maps a declaration to the
+  // index of the first binding info in the list.
+  llvm::SmallVector<DeclBindingInfo> BindingsList;
+  llvm::DenseMap<const VarDecl *, unsigned> DeclToBindingListIndex;
+};
+
 class SemaHLSL : public SemaBase {
 public:
   SemaHLSL(Sema &S);
@@ -56,6 +104,7 @@ class SemaHLSL : public SemaBase {
   mergeParamModifierAttr(Decl *D, const AttributeCommonInfo &AL,
                          HLSLParamModifierAttr::Spelling Spelling);
   void ActOnTopLevelFunction(FunctionDecl *FD);
+  void ActOnVariableDeclarator(VarDecl *VD);
   void CheckEntryPoint(FunctionDecl *FD);
   void CheckSemanticAnnotation(FunctionDecl *EntryPoint, const Decl *Param,
                                const HLSLAnnotationAttr *AnnotationAttr);
@@ -63,7 +112,6 @@ class SemaHLSL : public SemaBase {
       const Attr *A, llvm::Triple::EnvironmentType Stage,
       std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages);
   void DiagnoseAvailabilityViolations(TranslationUnitDecl *TU);
-  void ProcessResourceBindingOnDecl(VarDecl *D);
 
   QualType handleVectorBinOpConversion(ExprResult &LHS, ExprResult &RHS,
                                        QualType LHSType, QualType RHSType,
@@ -104,6 +152,15 @@ class SemaHLSL : public SemaBase {
   llvm::DenseMap<const HLSLAttributedResourceType *,
                  HLSLAttributedResourceLocInfo>
       LocsForHLSLAttributedResources;
+
+  // List of all resource bindings
+  ResourceBindings Bindings;
+
+private:
+  void FindResourcesOnVarDecl(VarDecl *D);
+  void FindResourcesOnUserRecordDecl(const VarDecl *VD, const RecordType *RT,
+                                     int Size);
+  void ProcessExplicitBindingsOnDecl(VarDecl *D);
 };
 
 } // namespace clang
diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp
index 8e27a5e068e702..770d00710a6816 100644
--- a/clang/lib/Sema/SemaDecl.cpp
+++ b/clang/lib/Sema/SemaDecl.cpp
@@ -7877,7 +7877,7 @@ NamedDecl *Sema::ActOnVariableDeclarator(
   ProcessDeclAttributes(S, NewVD, D);
 
   if (getLangOpts().HLSL)
-    HLSL().ProcessResourceBindingOnDecl(NewVD);
+    HLSL().ActOnVariableDeclarator(NewVD);
 
   // FIXME: This is probably the wrong location to be doing this and we should
   // probably be doing this for more attributes (especially for function
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 5c27a74a853bba..197ee63c07deeb 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -40,8 +40,8 @@
 #include <utility>
 
 using namespace clang;
-using llvm::dxil::ResourceClass;
 using RegisterType = HLSLResourceBindingAttr::RegisterType;
+
 static RegisterType getRegisterType(ResourceClass RC) {
   switch (RC) {
   case ResourceClass::SRV:
@@ -81,6 +81,49 @@ static RegisterType getRegisterType(StringRef Slot) {
   }
 }
 
+static ResourceClass getResourceClass(RegisterType RT) {
+  switch (RT) {
+  case RegisterType::SRV:
+    return ResourceClass::SRV;
+  case RegisterType::UAV:
+    return ResourceClass::UAV;
+  case RegisterType::CBuffer:
+    return ResourceClass::CBuffer;
+  case RegisterType::Sampler:
+    return ResourceClass::Sampler;
+  default:
+    llvm_unreachable("unexpected RegisterType value");
+  }
+}
+
+DeclBindingInfo *ResourceBindings::addDeclBindingInfo(const VarDecl *VD,
+                                                      ResourceClass ResClass,
+                                                      int Size) {
+  assert(getDeclBindingInfo(VD, ResClass) == nullptr &&
+         "DeclBindingInfo already added");
+  if (DeclToBindingListIndex.find(VD) == DeclToBindingListIndex.end())
+    DeclToBindingListIndex[VD] = BindingsList.size();
+  return &BindingsList.emplace_back(DeclBindingInfo(VD, ResClass, Size));
+}
+
+DeclBindingInfo *ResourceBindings::getDeclBindingInfo(const VarDecl *VD,
+                                                      ResourceClass ResClass) {
+  auto Entry = DeclToBindingListIndex.find(VD);
+  if (Entry != DeclToBindingListIndex.end()) {
+    unsigned Index = Entry->getSecond();
+    while (Index < BindingsList.size() && BindingsList[Index].Decl == VD) {
+      if (BindingsList[Index].ResClass == ResClass)
+        return &BindingsList[Index];
+      Index++;
+    }
+  }
+  return nullptr;
+}
+
+bool ResourceBindings::hasBindingInfoForDecl(const VarDecl *VD) {
+  return DeclToBindingListIndex.contains(VD);
+}
+
 SemaHLSL::SemaHLSL(Sema &S) : SemaBase(S) {}
 
 Decl *SemaHLSL::ActOnStartBuffer(Scope *BufferScope, bool CBuffer,
@@ -983,14 +1026,11 @@ SemaHLSL::TakeLocForHLSLAttribute(const HLSLAttributedResourceType *RT) {
   return LocInfo;
 }
 
-// Returns handle type of a resource, if the VarDecl is a resource
-// or an array of resources
+// Returns handle type of a resource, if the type is a resource
 static const HLSLAttributedResourceType *
-FindHandleTypeOnResource(const VarDecl *VD) {
-  // If VarDecl is a resource class, the first field must
+FindHandleTypeOnResource(const Type *Ty) {
+  // If Ty is a resource class, the first field must
   // be the resource handle of type HLSLAttributedResourceType
-  assert(VD != nullptr && "expected VarDecl");
-  const Type *Ty = VD->getType()->getPointeeOrArrayElementType();
   if (RecordDecl *RD = Ty->getAsCXXRecordDecl()) {
     if (!RD->fields().empty()) {
       const auto &FirstFD = RD->fields().begin();
@@ -1001,51 +1041,53 @@ FindHandleTypeOnResource(const VarDecl *VD) {
   return nullptr;
 }
 
-// Walks though the user defined record type, finds resource class
-// that matches the RegisterBinding.Type and assigns it to
-// RegisterBinding::Decl.
-static bool
-ProcessResourceBindingOnUserRecordDecl(const RecordType *RT,
-                                       HLSLResourceBindingAttr *RBA) {
-
-  llvm::SmallVector<const Type *> TypesToScan;
-  TypesToScan.emplace_back(RT);
-  RegisterType RegType = RBA->getRegisterType();
-
-  while (!TypesToScan.empty()) {
-    const Type *T = TypesToScan.pop_back_val();
-
-    while (T->isArrayType()) {
-      // FIXME: calculate the binding size from the array dimensions (or
-      // unbounded for unsized array) size *= (size_of_array);
-      T = T->getArrayElementTypeNoTypeQual();
-    }
-    if (T->isIntegralOrEnumerationType() || T->isFloatingType()) {
-      if (RegType == RegisterType::C)
-        return true;
+// Returns handle type of a resource, if the VarDecl is a resource
+static const HLSLAttributedResourceType *
+FindHandleTypeOnResource(const VarDecl *VD) {
+  assert(VD != nullptr && "expected VarDecl");
+  return FindHandleTypeOnResource(VD->getType().getTypePtr());
+}
+
+// Walks though the global variable declaration, collects all resource binding
+// requirements and adds them to Bindings
+void SemaHLSL::FindResourcesOnUserRecordDecl(const VarDecl *VD,
+                                             const RecordType *RT, int Size) {
+  const RecordDecl *RD = RT->getDecl();
+  for (FieldDecl *FD : RD->fields()) {
+    const Type *Ty = FD->getType()->getUnqualifiedDesugaredType();
+
+    // Calculate array size and unwrap
+    int ArraySize = 1;
+    assert(!Ty->isIncompleteArrayType() &&
+           "incomplete arrays inside user defined types are not supported");
+    while (Ty->isConstantArrayType()) {
+      const ConstantArrayType *CAT = cast<ConstantArrayType>(Ty);
+      ArraySize *= CAT->getSize().getSExtValue();
+      Ty = CAT->getElementType()->getUnqualifiedDesugaredType();
     }
-    const RecordType *RT = T->getAs<RecordType>();
-    if (!RT)
+
+    if (!Ty->isRecordType())
       continue;
 
-    const RecordDecl *RD = RT->getDecl();
-    for (FieldDecl *FD : RD->fields()) {
-      const Type *FieldTy = FD->getType().getTypePtr();
-      if (const HLSLAttributedResourceType *AttrResType =
-              dyn_cast<HLSLAttributedResourceType>(FieldTy)) {
-        ResourceClass RC = AttrResType->getAttrs().ResourceClass;
-        if (getRegisterType(RC) == RegType) {
-          assert(RBA->getResourceField() == nullptr &&
-                 "multiple register bindings of the same type are not allowed");
-          RBA->setResourceField(FD);
-          return true;
-        }
-      } else {
-        TypesToScan.emplace_back(FD->getType().getTypePtr());
-      }
+    // Field is a resource or array of resources
+    if (const HLSLAttributedResourceType *AttrResType =
+            FindHandleTypeOnResource(Ty)) {
+      ResourceClass RC = AttrResType->getAttrs().ResourceClass;
+
+      // Add a new DeclBindingInfo to Bindings. Update the binding size if
+      // a binding info already exists (there are multiple resources of same
+      // resource class in this user decl)
+      if (auto *DBI = Bindings.getDeclBindingInfo(VD, RC))
+        DBI->Size += Size * ArraySize;
+      else
+        Bindings.addDeclBindingInfo(VD, RC, Size);
+    } else if (const RecordType *RT = dyn_cast<RecordType>(Ty)) {
+      // Recursively scan embedded struct or class; it would be nice to do this
+      // without recursion, but tricky to corrently calculate the size.
+      // Hopefully nesting of structs in structs too many levels is unlikely.
+      FindResourcesOnUserRecordDecl(VD, RT, Size);
     }
   }
-  return false;
 }
 
 // return false if the register binding is not valid
@@ -1093,11 +1135,9 @@ static bool DiagnoseLocalRegisterBinding(Sema &S, SourceLocation &ArgLoc,
 
   // Basic types
   if (Ty->isArithmeticType()) {
-    bool IsValid = true;
     bool DeclaredInCOrTBuffer = isa<HLSLBufferDecl>(D->getDeclContext());
     if (SpecifiedSpace && !DeclaredInCOrTBuffer) {
       S.Diag(ArgLoc, diag::err_hlsl_space_on_global_constant);
-      IsValid = false;
     }
 
     if (!DeclaredInCOrTBuffer &&
@@ -1107,17 +1147,15 @@ static bool DiagnoseLocalRegisterBinding(Sema &S, SourceLocation &ArgLoc,
         S.Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_b);
       else if (RegType != RegisterType::C) {
         S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
-        IsValid = false;
       }
     } else {
       if (RegType == RegisterType::C)
         S.Diag(ArgLoc, diag::warn_hlsl_register_type_c_packoffset);
       else {
         S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
-        IsValid = false;
       }
     }
-    return IsValid;
+    return false;
   }
   if (Ty->isRecordType())
     // RecordTypes will be diagnosed in ProcessResourceBindingOnDecl
@@ -2057,6 +2095,7 @@ bool SemaHLSL::IsIntangibleType(clang::QualType QT) {
   CXXRecordDecl *RD = RT->getAsCXXRecordDecl();
   assert(RD != nullptr &&
          "all HLSL struct and classes should be CXXRecordDecl");
+  assert(RD->isCompleteDefinition() && "expecting complete type");
   return RD->isHLSLIntangible();
 }
 
@@ -2243,61 +2282,106 @@ QualType SemaHLSL::getInoutParameterType(QualType Ty) {
   return Ty;
 }
 
-// Walks though existing explicit bindings, finds the actual resource class
-// decl the binding applies to and sets it to attr->ResourceField.
-// Additional processing of resource binding can be added here later on,
-// such as preparation for overapping resource detection or implicit binding.
-void SemaHLSL::ProcessResourceBindingOnDecl(VarDecl *D) {
-  if (!D->hasGlobalStorage())
+void SemaHLSL::ActOnVariableDeclarator(VarDecl *VD) {
+  if (VD->hasGlobalStorage()) {
+    // make sure the declaration has a complete type
+    if (SemaRef.RequireCompleteType(
+            VD->getLocation(),
+            SemaRef.getASTContext().getBaseElementType(VD->getType()),
+            diag::err_typecheck_decl_incomplete_type)) {
+      VD->setInvalidDecl();
+      return;
+    }
+
+    // find all resources on decl
+    if (IsIntangibleType(VD->getType()))
+      FindResourcesOnVarDecl(VD);
+
+    // process explicit bindings
+    ProcessExplicitBindingsOnDecl(VD);
+  }
+}
+
+// Walks though the global variable declaration, collects all resource binding
+// requirements and adds them to Bindings
+void SemaHLSL::FindResourcesOnVarDecl(VarDecl *VD) {
+  assert(VD->hasGlobalStorage() && IsIntangibleType(VD->getType()) &&
+         "expected global variable that contains HLSL resource");
+
+  // Cbuffers and Tbuffers are HLSLBufferDecl types
+  if (const HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(VD)) {
+    Bindings.addDeclBindingInfo(VD,
+                                CBufferOrTBuffer->isCBuffer()
+                                    ? ResourceClass::CBuffer
+                                    : ResourceClass::SRV,
+                                1);
     return;
+  }
 
-  for (Attr *A : D->attrs()) {
-    HLSLResourceBindingAttr *RBA = dyn_cast<HLSLResourceBindingAttr>(A);
-    if (!RBA)
-      continue;
+  // Calculate size of array and unwrap
+  int Size = 1;
+  const Type *Ty = VD->getType()->getUnqualifiedDesugaredType();
+  if (Ty->isIncompleteArrayType())
+    Size = -1;
+  while (Ty->isConstantArrayType()) {
+    const ConstantArrayType *CAT = cast<ConstantArrayType>(Ty);
+    Size *= CAT->getSize().getSExtValue();
+    Ty = CAT->getElementType()->getUnqualifiedDesugaredType();
+  }
 
-    // // Cbuffers and Tbuffers are HLSLBufferDecl types
-    if (const HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D)) {
-      assert(RBA->getRegisterType() ==
-                 getRegisterType(CBufferOrTBuffer->isCBuffer()
-                                     ? ResourceClass::CBuffer
-                                     : ResourceClass::SRV) &&
-             "this should have been handled in DiagnoseLocalRegisterBinding");
-      // should we handle HLSLBufferDecl here?
-      continue;
-    }
+  // Resource (or array of resources)
+  if (const HLSLAttributedResourceType *AttrResType =
+          FindHandleTypeOnResource(Ty)) {
+    Bindings.addDeclBindingInfo(VD, AttrResType->getAttrs().ResourceClass,
+                                Size);
+    return;
+  }
 
-    // Samplers, UAVs, and SRVs are VarDecl types
-    assert(isa<VarDecl>(D) && "D is expected to be VarDecl or HLSLBufferDecl");
-    const VarDecl *VD = cast<VarDecl>(D);
+  assert(Size != -1 &&
+         "unbounded arrays of user defined types are not supported");
 
-    // Register binding directly on global resource class variable
-    if (const HLSLAttributedResourceType *AttrResType =
-            FindHandleTypeOnResource(VD)) {
-      // FIXME: if array, calculate the binding size from the array dimensions
-      // (or unbounded for unsized array)
-      assert(RBA->getResourceField() == nullptr);
+  // User defined record type
+  if (const RecordType *RT = dyn_cast<RecordType>(Ty))
+    FindResourcesOnUserRecordDecl(VD, RT, Size);
+}
+
+// Walks though the explicit resource binding attributes on the declaration,
+// and makes sure there is a resource that matched the binding and updates
+// DeclBindingInfoLists
+void SemaHLSL::ProcessExplicitBindingsOnDecl(VarDecl *VD) {
+  assert(VD->hasGlobalStorage() && "expected global variable");
+
+  for (Attr *A : VD->attrs()) {
+    HLSLResourceBindingAttr *RBA = dyn_cast<HLSLResourceBindingAttr>(A);
+    if (!RBA)
       continue;
-    }
 
-    // Global array
-    const clang::Type *Ty = VD->getType().getTypePtr();
-    while (Ty->isArrayType()) {
-      Ty = Ty->getArrayElementTypeNoTypeQual();
-    }
+    RegisterType RT = RBA->getRegisterType();
+    assert(RT != RegisterType::I && RT != RegisterType::Invalid &&
+           "invalid or obsolete register type should never have an attribute "
+           "created");
 
-    // Basic types
-    if (Ty->isArithmeticType()) {
+    // These were already diagnosed earlier
+    if (RT == RegisterType::C) {
+      if (Bindings.hasBindingInfoForDecl(VD))
+        SemaRef.Diag(VD->getLocation(),
+                     diag::warn_hlsl_user_defined_type_missing_member)
+            << static_cast<int>(RT);
       continue;
     }
 
-    if (Ty->isRecordType()) {
-      if (!ProcessResourceBindingOnUserRecordDecl(Ty->getAs<RecordType>(),
-                                                  RBA)) {
-        SemaRef.Diag(D->getLocation(),
-                     diag::warn_hlsl_user_defined_type_missing_member)
-            << static_cast<int>(RBA->getRegisterType());
-      }
+    // Find DeclBindingInfo for this binding and update it, or report error
+    // if it does not exist (user type does not contain resources with the
+    // expected resource class).
+    ResourceClass RC = getResourceClass(RT);
+    if (DeclBindingInfo *BI = Bindings.getDeclBindingInfo(VD, RC)) {
+      // update binding info
+      RBA->setSize(BI->Size);
+      BI->setBindingAttribute(RBA, BindingType::Explicit);
+    } else {
+      SemaRef.Diag(VD->getLocation(),
+                   diag::warn_hlsl_user_defined_type_missing_member)
+          << static_cast<int>(RT);
     }
   }
 }

>From aa6247f414b2bd3d39f349646f3a97ec72d5d517 Mon Sep 17 00:00:00 2001
From: Helena Kotas <hekotas at microsoft.com>
Date: Wed, 9 Oct 2024 17:08:25 -0700
Subject: [PATCH 4/6] removed unused variable, cleanup

---
 clang/lib/Sema/SemaHLSL.cpp | 14 +++-----------
 1 file changed, 3 insertions(+), 11 deletions(-)

diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 197ee63c07deeb..0423340ee5fc4f 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1136,24 +1136,21 @@ static bool DiagnoseLocalRegisterBinding(Sema &S, SourceLocation &ArgLoc,
   // Basic types
   if (Ty->isArithmeticType()) {
     bool DeclaredInCOrTBuffer = isa<HLSLBufferDecl>(D->getDeclContext());
-    if (SpecifiedSpace && !DeclaredInCOrTBuffer) {
+    if (SpecifiedSpace && !DeclaredInCOrTBuffer)
       S.Diag(ArgLoc, diag::err_hlsl_space_on_global_constant);
-    }
 
     if (!DeclaredInCOrTBuffer &&
         (Ty->isIntegralType(S.getASTContext()) || Ty->isFloatingType())) {
       // Default Globals
       if (RegType == RegisterType::CBuffer)
         S.Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_b);
-      else if (RegType != RegisterType::C) {
+      else if (RegType != RegisterType::C)
         S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
-      }
     } else {
       if (RegType == RegisterType::C)
         S.Diag(ArgLoc, diag::warn_hlsl_register_type_c_packoffset);
-      else {
+      else
         S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
-      }
     }
     return false;
   }
@@ -1172,13 +1169,8 @@ static bool ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl,
   // make sure that there are no two register annotations
   // applied to the decl with the same register type
   bool RegisterTypesDetected[5] = {false};
-
   RegisterTypesDetected[static_cast<int>(regType)] = true;
 
-  // we need a static map to keep track of previous conflicts
-  // so that we don't emit the same error multiple times
-  static std::map<Decl *, std::set<RegisterType>> PreviousConflicts;
-
   for (auto it = TheDecl->attr_begin(); it != TheDecl->attr_end(); ++it) {
     if (HLSLResourceBindingAttr *attr =
             dyn_cast<HLSLResourceBindingAttr>(*it)) {

>From a6edabe43eefc2957932498ee35b71e800af9fdd Mon Sep 17 00:00:00 2001
From: Helena Kotas <hekotas at microsoft.com>
Date: Tue, 15 Oct 2024 21:14:19 -0700
Subject: [PATCH 5/6] Code review feedback

- remove size calculation and storage - it is currently not used or tested
- remove invalid register type
- set fields on HLSLResourceBindingAttr as private and accessors public, add const
- update function names
- update comments
- use more effective SmallVector and DenseMap methods
---
 clang/include/clang/Basic/Attr.td   |  31 +++----
 clang/include/clang/Sema/SemaHLSL.h |  16 ++--
 clang/lib/Sema/SemaHLSL.cpp         | 121 ++++++++++++++--------------
 3 files changed, 75 insertions(+), 93 deletions(-)

diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index 916757ccbe2d47..0259b6e40ca962 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -4594,42 +4594,29 @@ def HLSLResourceBinding: InheritableAttr {
   let Args = [StringArgument<"Slot">, StringArgument<"Space", 1>];
   let Documentation = [HLSLResourceBindingDocs];
   let AdditionalMembers = [{
-      enum class RegisterType : unsigned { SRV, UAV, CBuffer, Sampler, C, I, Invalid };
-
+  public:
+      enum class RegisterType : unsigned { SRV, UAV, CBuffer, Sampler, C, I };
+  
+  private:
       RegisterType RegType;
       unsigned SlotNumber;
       unsigned SpaceNumber;
-      
-      // Size of the binding
-      // 0 == not set
-      //-1 == unbounded
-      int Size;
 
-      void setBinding(RegisterType RT, unsigned SlotNum, unsigned SpaceNum, int Size = 0) {
+  public:
+      void setBinding(RegisterType RT, unsigned SlotNum, unsigned SpaceNum) {
         RegType = RT;
         SlotNumber = SlotNum;
         SpaceNumber = SpaceNum;
       }
-      RegisterType getRegisterType() {
+      RegisterType getRegisterType() const {
         return RegType;
       }
-      unsigned getSlotNumber() {
+      unsigned getSlotNumber() const {
         return SlotNumber;
       }
-      unsigned getSpaceNumber() {
+      unsigned getSpaceNumber() const {
         return SpaceNumber;
       }
-      unsigned getSize() {
-        assert(Size == -1 || Size > 0 && "size not set");
-        return Size;
-      }
-      void setSize(int N) {
-        assert(N == -1 || N > 0 && "unexpected size value");
-        Size = N;
-      }
-      bool isSizeUnbounded() {
-        return Size == -1;
-      }
   }];
 }
 
diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h
index ce262fd41dff37..5eda4d544a5ae5 100644
--- a/clang/include/clang/Sema/SemaHLSL.h
+++ b/clang/include/clang/Sema/SemaHLSL.h
@@ -45,15 +45,13 @@ enum class BindingType : uint8_t { NotAssigned, Explicit, Implicit };
 struct DeclBindingInfo {
   const VarDecl *Decl;
   ResourceClass ResClass;
-  int Size; // -1 == unbounded array
   const HLSLResourceBindingAttr *Attr;
   BindingType BindType;
 
   DeclBindingInfo(const VarDecl *Decl, ResourceClass ResClass, int Size = 0,
                   BindingType BindType = BindingType::NotAssigned,
                   const HLSLResourceBindingAttr *Attr = nullptr)
-      : Decl(Decl), ResClass(ResClass), Size(Size), Attr(Attr),
-        BindType(BindType) {}
+      : Decl(Decl), ResClass(ResClass), Attr(Attr), BindType(BindType) {}
 
   void setBindingAttribute(HLSLResourceBindingAttr *A, BindingType BT) {
     assert(Attr == nullptr && BindType == BindingType::NotAssigned &&
@@ -68,8 +66,8 @@ struct DeclBindingInfo {
 // assigments.
 class ResourceBindings {
 public:
-  DeclBindingInfo *addDeclBindingInfo(const VarDecl *VD, ResourceClass ResClass,
-                                      int Size);
+  DeclBindingInfo *addDeclBindingInfo(const VarDecl *VD,
+                                      ResourceClass ResClass);
   DeclBindingInfo *getDeclBindingInfo(const VarDecl *VD,
                                       ResourceClass ResClass);
   bool hasBindingInfoForDecl(const VarDecl *VD);
@@ -157,10 +155,10 @@ class SemaHLSL : public SemaBase {
   ResourceBindings Bindings;
 
 private:
-  void FindResourcesOnVarDecl(VarDecl *D);
-  void FindResourcesOnUserRecordDecl(const VarDecl *VD, const RecordType *RT,
-                                     int Size);
-  void ProcessExplicitBindingsOnDecl(VarDecl *D);
+  void collectResourcesOnVarDecl(VarDecl *D);
+  void collectResourcesOnUserRecordDecl(const VarDecl *VD,
+                                        const RecordType *RT);
+  void processExplicitBindingsOnDecl(VarDecl *D);
 };
 
 } // namespace clang
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 668d3ad9ecd6ba..a58c4281eeb375 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -56,28 +56,37 @@ static RegisterType getRegisterType(ResourceClass RC) {
   llvm_unreachable("unexpected ResourceClass value");
 }
 
-static RegisterType getRegisterType(StringRef Slot) {
+// Converts the first letter of string Slot to RegisterType.
+// Returns false if the letter does not correspond to a valid register type.
+static bool convertToRegisterType(StringRef Slot, RegisterType *RT) {
+  assert(RT != nullptr);
   switch (Slot[0]) {
   case 't':
   case 'T':
-    return RegisterType::SRV;
+    *RT = RegisterType::SRV;
+    return true;
   case 'u':
   case 'U':
-    return RegisterType::UAV;
+    *RT = RegisterType::UAV;
+    return true;
   case 'b':
   case 'B':
-    return RegisterType::CBuffer;
+    *RT = RegisterType::CBuffer;
+    return true;
   case 's':
   case 'S':
-    return RegisterType::Sampler;
+    *RT = RegisterType::Sampler;
+    return true;
   case 'c':
   case 'C':
-    return RegisterType::C;
+    *RT = RegisterType::C;
+    return true;
   case 'i':
   case 'I':
-    return RegisterType::I;
+    *RT = RegisterType::I;
+    return true;
   default:
-    return RegisterType::Invalid;
+    return false;
   }
 }
 
@@ -91,19 +100,18 @@ static ResourceClass getResourceClass(RegisterType RT) {
     return ResourceClass::CBuffer;
   case RegisterType::Sampler:
     return ResourceClass::Sampler;
-  default:
+  case RegisterType::C:
+  case RegisterType::I:
     llvm_unreachable("unexpected RegisterType value");
   }
 }
 
 DeclBindingInfo *ResourceBindings::addDeclBindingInfo(const VarDecl *VD,
-                                                      ResourceClass ResClass,
-                                                      int Size) {
+                                                      ResourceClass ResClass) {
   assert(getDeclBindingInfo(VD, ResClass) == nullptr &&
          "DeclBindingInfo already added");
-  if (DeclToBindingListIndex.find(VD) == DeclToBindingListIndex.end())
-    DeclToBindingListIndex[VD] = BindingsList.size();
-  return &BindingsList.emplace_back(DeclBindingInfo(VD, ResClass, Size));
+  DeclToBindingListIndex.try_emplace(VD, BindingsList.size());
+  return &BindingsList.emplace_back(VD, ResClass);
 }
 
 DeclBindingInfo *ResourceBindings::getDeclBindingInfo(const VarDecl *VD,
@@ -1050,19 +1058,18 @@ FindHandleTypeOnResource(const VarDecl *VD) {
 
 // Walks though the global variable declaration, collects all resource binding
 // requirements and adds them to Bindings
-void SemaHLSL::FindResourcesOnUserRecordDecl(const VarDecl *VD,
-                                             const RecordType *RT, int Size) {
+void SemaHLSL::collectResourcesOnUserRecordDecl(const VarDecl *VD,
+                                                const RecordType *RT) {
   const RecordDecl *RD = RT->getDecl();
   for (FieldDecl *FD : RD->fields()) {
     const Type *Ty = FD->getType()->getUnqualifiedDesugaredType();
 
-    // Calculate array size and unwrap
-    int ArraySize = 1;
+    // Unwrap arrays
+    // FIXME: Calculate array size while unwrapping
     assert(!Ty->isIncompleteArrayType() &&
            "incomplete arrays inside user defined types are not supported");
     while (Ty->isConstantArrayType()) {
       const ConstantArrayType *CAT = cast<ConstantArrayType>(Ty);
-      ArraySize *= CAT->getSize().getSExtValue();
       Ty = CAT->getElementType()->getUnqualifiedDesugaredType();
     }
 
@@ -1074,23 +1081,26 @@ void SemaHLSL::FindResourcesOnUserRecordDecl(const VarDecl *VD,
             FindHandleTypeOnResource(Ty)) {
       ResourceClass RC = AttrResType->getAttrs().ResourceClass;
 
-      // Add a new DeclBindingInfo to Bindings. Update the binding size if
-      // a binding info already exists (there are multiple resources of same
-      // resource class in this user decl)
-      if (auto *DBI = Bindings.getDeclBindingInfo(VD, RC))
-        DBI->Size += Size * ArraySize;
-      else
-        Bindings.addDeclBindingInfo(VD, RC, Size);
+      // Add a new DeclBindingInfo to Bindings if it does not already exist
+      DeclBindingInfo *DBI = Bindings.getDeclBindingInfo(VD, RC);
+      if (!DBI)
+        Bindings.addDeclBindingInfo(VD, RC);
     } else if (const RecordType *RT = dyn_cast<RecordType>(Ty)) {
       // Recursively scan embedded struct or class; it would be nice to do this
-      // without recursion, but tricky to corrently calculate the size.
-      // Hopefully nesting of structs in structs too many levels is unlikely.
-      FindResourcesOnUserRecordDecl(VD, RT, Size);
+      // without recursion, but tricky to correctly calculate the size of the
+      // binding, which is something we are probably going to need to do later
+      // on. Hopefully nesting of structs in structs too many levels is
+      // unlikely.
+      collectResourcesOnUserRecordDecl(VD, RT);
     }
   }
 }
 
-// return false if the register binding is not valid
+// Diagnose localized register binding errors for a single binding; does not
+// diagnose resource binding on user record types, that will be done later
+// in processExplicitBindingsOnDecl based on the information collected in
+// collectResourcesOnVarDecl.
+// Returns false if the register binding is not valid.
 static bool DiagnoseLocalRegisterBinding(Sema &S, SourceLocation &ArgLoc,
                                          Decl *D, RegisterType RegType,
                                          bool SpecifiedSpace) {
@@ -1155,7 +1165,7 @@ static bool DiagnoseLocalRegisterBinding(Sema &S, SourceLocation &ArgLoc,
     return false;
   }
   if (Ty->isRecordType())
-    // RecordTypes will be diagnosed in ProcessResourceBindingOnDecl
+    // RecordTypes will be diagnosed in processExplicitBindingsOnDecl
     // that is called from ActOnVariableDeclarator
     return true;
 
@@ -1175,7 +1185,7 @@ static bool ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl,
     if (HLSLResourceBindingAttr *attr =
             dyn_cast<HLSLResourceBindingAttr>(*it)) {
 
-      RegisterType otherRegType = getRegisterType(attr->getSlot());
+      RegisterType otherRegType = attr->getRegisterType();
       if (RegisterTypesDetected[static_cast<int>(otherRegType)]) {
         int otherRegTypeNum = static_cast<int>(otherRegType);
         S.Diag(TheDecl->getLocation(),
@@ -1250,13 +1260,12 @@ void SemaHLSL::handleResourceBindingAttr(Decl *TheDecl, const ParsedAttr &AL) {
 
   // Validate.
   if (!Slot.empty()) {
-    RegType = getRegisterType(Slot);
-    if (RegType == RegisterType::I) {
-      Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_i);
+    if (!convertToRegisterType(Slot, &RegType)) {
+      Diag(ArgLoc, diag::err_hlsl_binding_type_invalid) << Slot.substr(0, 1);
       return;
     }
-    if (RegType == RegisterType::Invalid) {
-      Diag(ArgLoc, diag::err_hlsl_binding_type_invalid) << Slot.substr(0, 1);
+    if (RegType == RegisterType::I) {
+      Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_i);
       return;
     }
 
@@ -2294,60 +2303,51 @@ void SemaHLSL::ActOnVariableDeclarator(VarDecl *VD) {
 
     // find all resources on decl
     if (IsIntangibleType(VD->getType()))
-      FindResourcesOnVarDecl(VD);
+      collectResourcesOnVarDecl(VD);
 
     // process explicit bindings
-    ProcessExplicitBindingsOnDecl(VD);
+    processExplicitBindingsOnDecl(VD);
   }
 }
 
 // Walks though the global variable declaration, collects all resource binding
 // requirements and adds them to Bindings
-void SemaHLSL::FindResourcesOnVarDecl(VarDecl *VD) {
+void SemaHLSL::collectResourcesOnVarDecl(VarDecl *VD) {
   assert(VD->hasGlobalStorage() && IsIntangibleType(VD->getType()) &&
          "expected global variable that contains HLSL resource");
 
   // Cbuffers and Tbuffers are HLSLBufferDecl types
   if (const HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(VD)) {
-    Bindings.addDeclBindingInfo(VD,
-                                CBufferOrTBuffer->isCBuffer()
-                                    ? ResourceClass::CBuffer
-                                    : ResourceClass::SRV,
-                                1);
+    Bindings.addDeclBindingInfo(VD, CBufferOrTBuffer->isCBuffer()
+                                        ? ResourceClass::CBuffer
+                                        : ResourceClass::SRV);
     return;
   }
 
-  // Calculate size of array and unwrap
-  int Size = 1;
+  // Unwrap arrays
+  // FIXME: Calculate array size while unwrapping
   const Type *Ty = VD->getType()->getUnqualifiedDesugaredType();
-  if (Ty->isIncompleteArrayType())
-    Size = -1;
   while (Ty->isConstantArrayType()) {
     const ConstantArrayType *CAT = cast<ConstantArrayType>(Ty);
-    Size *= CAT->getSize().getSExtValue();
     Ty = CAT->getElementType()->getUnqualifiedDesugaredType();
   }
 
   // Resource (or array of resources)
   if (const HLSLAttributedResourceType *AttrResType =
           FindHandleTypeOnResource(Ty)) {
-    Bindings.addDeclBindingInfo(VD, AttrResType->getAttrs().ResourceClass,
-                                Size);
+    Bindings.addDeclBindingInfo(VD, AttrResType->getAttrs().ResourceClass);
     return;
   }
 
-  assert(Size != -1 &&
-         "unbounded arrays of user defined types are not supported");
-
   // User defined record type
   if (const RecordType *RT = dyn_cast<RecordType>(Ty))
-    FindResourcesOnUserRecordDecl(VD, RT, Size);
+    collectResourcesOnUserRecordDecl(VD, RT);
 }
 
 // Walks though the explicit resource binding attributes on the declaration,
 // and makes sure there is a resource that matched the binding and updates
 // DeclBindingInfoLists
-void SemaHLSL::ProcessExplicitBindingsOnDecl(VarDecl *VD) {
+void SemaHLSL::processExplicitBindingsOnDecl(VarDecl *VD) {
   assert(VD->hasGlobalStorage() && "expected global variable");
 
   for (Attr *A : VD->attrs()) {
@@ -2356,11 +2356,9 @@ void SemaHLSL::ProcessExplicitBindingsOnDecl(VarDecl *VD) {
       continue;
 
     RegisterType RT = RBA->getRegisterType();
-    assert(RT != RegisterType::I && RT != RegisterType::Invalid &&
-           "invalid or obsolete register type should never have an attribute "
-           "created");
+    assert(RT != RegisterType::I && "invalid or obsolete register type should "
+                                    "never have an attribute created");
 
-    // These were already diagnosed earlier
     if (RT == RegisterType::C) {
       if (Bindings.hasBindingInfoForDecl(VD))
         SemaRef.Diag(VD->getLocation(),
@@ -2375,7 +2373,6 @@ void SemaHLSL::ProcessExplicitBindingsOnDecl(VarDecl *VD) {
     ResourceClass RC = getResourceClass(RT);
     if (DeclBindingInfo *BI = Bindings.getDeclBindingInfo(VD, RC)) {
       // update binding info
-      RBA->setSize(BI->Size);
       BI->setBindingAttribute(RBA, BindingType::Explicit);
     } else {
       SemaRef.Diag(VD->getLocation(),

>From a7fbeaebf2a5ac502ab1c00787e0c51ee807210e Mon Sep 17 00:00:00 2001
From: Helena Kotas <hekotas at microsoft.com>
Date: Wed, 16 Oct 2024 10:14:57 -0700
Subject: [PATCH 6/6] More cleanup - rename function to lowerCase and remove
 one overload, remove Size argument, remove comment

---
 clang/include/clang/Sema/SemaHLSL.h |  2 +-
 clang/lib/Sema/SemaHLSL.cpp         | 19 +++++--------------
 2 files changed, 6 insertions(+), 15 deletions(-)

diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h
index 5eda4d544a5ae5..31b4c5b748e189 100644
--- a/clang/include/clang/Sema/SemaHLSL.h
+++ b/clang/include/clang/Sema/SemaHLSL.h
@@ -48,7 +48,7 @@ struct DeclBindingInfo {
   const HLSLResourceBindingAttr *Attr;
   BindingType BindType;
 
-  DeclBindingInfo(const VarDecl *Decl, ResourceClass ResClass, int Size = 0,
+  DeclBindingInfo(const VarDecl *Decl, ResourceClass ResClass,
                   BindingType BindType = BindingType::NotAssigned,
                   const HLSLResourceBindingAttr *Attr = nullptr)
       : Decl(Decl), ResClass(ResClass), Attr(Attr), BindType(BindType) {}
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 98c0100afbc5c9..de7951aa1b9088 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1036,7 +1036,7 @@ SemaHLSL::TakeLocForHLSLAttribute(const HLSLAttributedResourceType *RT) {
 
 // Returns handle type of a resource, if the type is a resource
 static const HLSLAttributedResourceType *
-FindHandleTypeOnResource(const Type *Ty) {
+findHandleTypeOnResource(const Type *Ty) {
   // If Ty is a resource class, the first field must
   // be the resource handle of type HLSLAttributedResourceType
   if (RecordDecl *RD = Ty->getAsCXXRecordDecl()) {
@@ -1049,13 +1049,6 @@ FindHandleTypeOnResource(const Type *Ty) {
   return nullptr;
 }
 
-// Returns handle type of a resource, if the VarDecl is a resource
-static const HLSLAttributedResourceType *
-FindHandleTypeOnResource(const VarDecl *VD) {
-  assert(VD != nullptr && "expected VarDecl");
-  return FindHandleTypeOnResource(VD->getType().getTypePtr());
-}
-
 // Walks though the global variable declaration, collects all resource binding
 // requirements and adds them to Bindings
 void SemaHLSL::collectResourcesOnUserRecordDecl(const VarDecl *VD,
@@ -1076,12 +1069,10 @@ void SemaHLSL::collectResourcesOnUserRecordDecl(const VarDecl *VD,
     if (!Ty->isRecordType())
       continue;
 
-    // Field is a resource or array of resources
     if (const HLSLAttributedResourceType *AttrResType =
-            FindHandleTypeOnResource(Ty)) {
-      ResourceClass RC = AttrResType->getAttrs().ResourceClass;
-
+            findHandleTypeOnResource(Ty)) {
       // Add a new DeclBindingInfo to Bindings if it does not already exist
+      ResourceClass RC = AttrResType->getAttrs().ResourceClass;
       DeclBindingInfo *DBI = Bindings.getDeclBindingInfo(VD, RC);
       if (!DBI)
         Bindings.addDeclBindingInfo(VD, RC);
@@ -1130,7 +1121,7 @@ static bool DiagnoseLocalRegisterBinding(Sema &S, SourceLocation &ArgLoc,
 
   // Resource
   if (const HLSLAttributedResourceType *AttrResType =
-          FindHandleTypeOnResource(VD)) {
+          findHandleTypeOnResource(VD->getType().getTypePtr())) {
     if (RegType == getRegisterType(AttrResType->getAttrs().ResourceClass))
       return true;
 
@@ -2373,7 +2364,7 @@ void SemaHLSL::collectResourcesOnVarDecl(VarDecl *VD) {
 
   // Resource (or array of resources)
   if (const HLSLAttributedResourceType *AttrResType =
-          FindHandleTypeOnResource(Ty)) {
+          findHandleTypeOnResource(Ty)) {
     Bindings.addDeclBindingInfo(VD, AttrResType->getAttrs().ResourceClass);
     return;
   }



More information about the cfe-commits mailing list