[clang] 4512bbe - [HLSL] Collect explicit resource binding information (#111203)

via cfe-commits cfe-commits at lists.llvm.org
Wed Oct 16 21:24:17 PDT 2024


Author: Helena Kotas
Date: 2024-10-16T21:24:13-07:00
New Revision: 4512bbe7467c1c0f884304e5654d1070df58d6f8

URL: https://github.com/llvm/llvm-project/commit/4512bbe7467c1c0f884304e5654d1070df58d6f8
DIFF: https://github.com/llvm/llvm-project/commit/4512bbe7467c1c0f884304e5654d1070df58d6f8.diff

LOG: [HLSL] Collect explicit resource binding information (#111203)

Scans each global variable declaration and its members and collects all
required resource bindings in a new `SemaHLSL` data member `Bindings`.

New fields are added to `HLSLResourceBindingAttr` for storing processed
binding information so that it can be used by CodeGen (`Bindings` or any
other Sema information is not accessible from CodeGen).

Adjusts the existing register binding attribute handling and diagnostics
to:
- do not create HLSLResourceBindingAttribute if it is not valid
- diagnose only the simple/local errors when a register binding
attribute is parsed
- additional diagnostic of binding type mismatches is done later and
uses the new `Bindings` data

Fixes #110719

Added: 
    

Modified: 
    clang/include/clang/Basic/Attr.td
    clang/include/clang/Sema/SemaHLSL.h
    clang/lib/Sema/SemaDecl.cpp
    clang/lib/Sema/SemaHLSL.cpp
    clang/test/SemaHLSL/resource_binding_attr_error_udt.hlsl

Removed: 
    


################################################################################
diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index ec3d6e0079f630..0259b6e40ca962 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -4593,6 +4593,31 @@ def HLSLResourceBinding: InheritableAttr {
   let LangOpts = [HLSL];
   let Args = [StringArgument<"Slot">, StringArgument<"Space", 1>];
   let Documentation = [HLSLResourceBindingDocs];
+  let AdditionalMembers = [{
+  public:
+      enum class RegisterType : unsigned { SRV, UAV, CBuffer, Sampler, C, I };
+  
+  private:
+      RegisterType RegType;
+      unsigned SlotNumber;
+      unsigned SpaceNumber;
+
+  public:
+      void setBinding(RegisterType RT, unsigned SlotNum, unsigned SpaceNum) {
+        RegType = RT;
+        SlotNumber = SlotNum;
+        SpaceNumber = SpaceNum;
+      }
+      RegisterType getRegisterType() const {
+        return RegType;
+      }
+      unsigned getSlotNumber() const {
+        return SlotNumber;
+      }
+      unsigned getSpaceNumber() const {
+        return SpaceNumber;
+      }
+  }];
 }
 
 def HLSLPackOffset: HLSLAnnotationAttr {

diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h
index fa957abc9791af..4f1fc9a31404c6 100644
--- a/clang/include/clang/Sema/SemaHLSL.h
+++ b/clang/include/clang/Sema/SemaHLSL.h
@@ -28,6 +28,9 @@ class AttributeCommonInfo;
 class IdentifierInfo;
 class ParsedAttr;
 class Scope;
+class VarDecl;
+
+using llvm::dxil::ResourceClass;
 
 // FIXME: This can be hidden (as static function in SemaHLSL.cpp) once we no
 // longer need to create builtin buffer types in HLSLExternalSemaSource.
@@ -35,6 +38,50 @@ bool CreateHLSLAttributedResourceType(
     Sema &S, QualType Wrapped, ArrayRef<const Attr *> AttrList,
     QualType &ResType, HLSLAttributedResourceLocInfo *LocInfo = nullptr);
 
+enum class BindingType : uint8_t { NotAssigned, Explicit, Implicit };
+
+// DeclBindingInfo struct stores information about required/assigned resource
+// binding on a declaration for a specific resource class.
+struct DeclBindingInfo {
+  const VarDecl *Decl;
+  ResourceClass ResClass;
+  const HLSLResourceBindingAttr *Attr;
+  BindingType BindType;
+
+  DeclBindingInfo(const VarDecl *Decl, ResourceClass ResClass,
+                  BindingType BindType = BindingType::NotAssigned,
+                  const HLSLResourceBindingAttr *Attr = nullptr)
+      : Decl(Decl), ResClass(ResClass), Attr(Attr), BindType(BindType) {}
+
+  void setBindingAttribute(HLSLResourceBindingAttr *A, BindingType BT) {
+    assert(Attr == nullptr && BindType == BindingType::NotAssigned &&
+           "binding attribute already assigned");
+    Attr = A;
+    BindType = BT;
+  }
+};
+
+// ResourceBindings class stores information about all resource bindings
+// in a shader. It is used for binding diagnostics and implicit binding
+// assignments.
+class ResourceBindings {
+public:
+  DeclBindingInfo *addDeclBindingInfo(const VarDecl *VD,
+                                      ResourceClass ResClass);
+  DeclBindingInfo *getDeclBindingInfo(const VarDecl *VD,
+                                      ResourceClass ResClass);
+  bool hasBindingInfoForDecl(const VarDecl *VD) const;
+
+private:
+  // List of all resource bindings required by the shader.
+  // A global declaration can have multiple bindings for different
+  // resource classes. They are all stored sequentially in this list.
+  // The DeclToBindingListIndex hashtable maps a declaration to the
+  // index of the first binding info in the list.
+  llvm::SmallVector<DeclBindingInfo> BindingsList;
+  llvm::DenseMap<const VarDecl *, unsigned> DeclToBindingListIndex;
+};
+
 class SemaHLSL : public SemaBase {
 public:
   SemaHLSL(Sema &S);
@@ -55,6 +102,7 @@ class SemaHLSL : public SemaBase {
   mergeParamModifierAttr(Decl *D, const AttributeCommonInfo &AL,
                          HLSLParamModifierAttr::Spelling Spelling);
   void ActOnTopLevelFunction(FunctionDecl *FD);
+  void ActOnVariableDeclarator(VarDecl *VD);
   void CheckEntryPoint(FunctionDecl *FD);
   void CheckSemanticAnnotation(FunctionDecl *EntryPoint, const Decl *Param,
                                const HLSLAnnotationAttr *AnnotationAttr);
@@ -102,6 +150,15 @@ class SemaHLSL : public SemaBase {
   llvm::DenseMap<const HLSLAttributedResourceType *,
                  HLSLAttributedResourceLocInfo>
       LocsForHLSLAttributedResources;
+
+  // List of all resource bindings
+  ResourceBindings Bindings;
+
+private:
+  void collectResourcesOnVarDecl(VarDecl *D);
+  void collectResourcesOnUserRecordDecl(const VarDecl *VD,
+                                        const RecordType *RT);
+  void processExplicitBindingsOnDecl(VarDecl *D);
 };
 
 } // namespace clang

diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp
index fece22c663d00c..229c9080d558ec 100644
--- a/clang/lib/Sema/SemaDecl.cpp
+++ b/clang/lib/Sema/SemaDecl.cpp
@@ -7883,6 +7883,9 @@ NamedDecl *Sema::ActOnVariableDeclarator(
   // Handle attributes prior to checking for duplicates in MergeVarDecl
   ProcessDeclAttributes(S, NewVD, D);
 
+  if (getLangOpts().HLSL)
+    HLSL().ActOnVariableDeclarator(NewVD);
+
   // FIXME: This is probably the wrong location to be doing this and we should
   // probably be doing this for more attributes (especially for function
   // pointer attributes such as format, warn_unused_result, etc.). Ideally

diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 698fdbed0484e5..0d23c4935e9196 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -40,9 +40,7 @@
 #include <utility>
 
 using namespace clang;
-using llvm::dxil::ResourceClass;
-
-enum class RegisterType { SRV, UAV, CBuffer, Sampler, C, I, Invalid };
+using RegisterType = HLSLResourceBindingAttr::RegisterType;
 
 static RegisterType getRegisterType(ResourceClass RC) {
   switch (RC) {
@@ -58,31 +56,95 @@ static RegisterType getRegisterType(ResourceClass RC) {
   llvm_unreachable("unexpected ResourceClass value");
 }
 
-static RegisterType getRegisterType(StringRef Slot) {
+// Converts the first letter of string Slot to RegisterType.
+// Returns false if the letter does not correspond to a valid register type.
+static bool convertToRegisterType(StringRef Slot, RegisterType *RT) {
+  assert(RT != nullptr);
   switch (Slot[0]) {
   case 't':
   case 'T':
-    return RegisterType::SRV;
+    *RT = RegisterType::SRV;
+    return true;
   case 'u':
   case 'U':
-    return RegisterType::UAV;
+    *RT = RegisterType::UAV;
+    return true;
   case 'b':
   case 'B':
-    return RegisterType::CBuffer;
+    *RT = RegisterType::CBuffer;
+    return true;
   case 's':
   case 'S':
-    return RegisterType::Sampler;
+    *RT = RegisterType::Sampler;
+    return true;
   case 'c':
   case 'C':
-    return RegisterType::C;
+    *RT = RegisterType::C;
+    return true;
   case 'i':
   case 'I':
-    return RegisterType::I;
+    *RT = RegisterType::I;
+    return true;
   default:
-    return RegisterType::Invalid;
+    return false;
   }
 }
 
+static ResourceClass getResourceClass(RegisterType RT) {
+  switch (RT) {
+  case RegisterType::SRV:
+    return ResourceClass::SRV;
+  case RegisterType::UAV:
+    return ResourceClass::UAV;
+  case RegisterType::CBuffer:
+    return ResourceClass::CBuffer;
+  case RegisterType::Sampler:
+    return ResourceClass::Sampler;
+  case RegisterType::C:
+  case RegisterType::I:
+    llvm_unreachable("unexpected RegisterType value");
+  }
+}
+
+DeclBindingInfo *ResourceBindings::addDeclBindingInfo(const VarDecl *VD,
+                                                      ResourceClass ResClass) {
+  assert(getDeclBindingInfo(VD, ResClass) == nullptr &&
+         "DeclBindingInfo already added");
+#ifndef NDEBUG
+  // Verify that existing bindings for this decl are stored sequentially
+  // and at the end of the BindingsList
+  auto I = DeclToBindingListIndex.find(VD);
+  if (I != DeclToBindingListIndex.end()) {
+    for (unsigned Index = I->getSecond(); Index < BindingsList.size(); ++Index)
+      assert(BindingsList[Index].Decl == VD);
+  }
+#endif
+  // VarDecl may have multiple entries for different resource classes.
+  // DeclToBindingListIndex stores the index of the first binding we saw
+  // for this decl. If there are any additional ones then that index
+  // shouldn't be updated.
+  DeclToBindingListIndex.try_emplace(VD, BindingsList.size());
+  return &BindingsList.emplace_back(VD, ResClass);
+}
+
+DeclBindingInfo *ResourceBindings::getDeclBindingInfo(const VarDecl *VD,
+                                                      ResourceClass ResClass) {
+  auto Entry = DeclToBindingListIndex.find(VD);
+  if (Entry != DeclToBindingListIndex.end()) {
+    for (unsigned Index = Entry->getSecond();
+         Index < BindingsList.size() && BindingsList[Index].Decl == VD;
+         ++Index) {
+      if (BindingsList[Index].ResClass == ResClass)
+        return &BindingsList[Index];
+    }
+  }
+  return nullptr;
+}
+
+bool ResourceBindings::hasBindingInfoForDecl(const VarDecl *VD) const {
+  return DeclToBindingListIndex.contains(VD);
+}
+
 SemaHLSL::SemaHLSL(Sema &S) : SemaBase(S) {}
 
 Decl *SemaHLSL::ActOnStartBuffer(Scope *BufferScope, bool CBuffer,
@@ -985,88 +1047,85 @@ SemaHLSL::TakeLocForHLSLAttribute(const HLSLAttributedResourceType *RT) {
   return LocInfo;
 }
 
-// get the record decl from a var decl that we expect
-// represents a resource
-static CXXRecordDecl *getRecordDeclFromVarDecl(VarDecl *VD) {
-  const Type *Ty = VD->getType()->getPointeeOrArrayElementType();
-  assert(Ty && "Resource must have an element type.");
-
-  if (Ty->isBuiltinType())
-    return nullptr;
-
-  CXXRecordDecl *TheRecordDecl = Ty->getAsCXXRecordDecl();
-  assert(TheRecordDecl && "Resource should have a resource type declaration.");
-  return TheRecordDecl;
-}
-
+// Returns handle type of a resource, if the type is a resource
 static const HLSLAttributedResourceType *
-findAttributedResourceTypeOnField(VarDecl *VD) {
-  assert(VD != nullptr && "expected VarDecl");
-  if (RecordDecl *RD = getRecordDeclFromVarDecl(VD)) {
-    for (auto *FD : RD->fields()) {
-      if (const HLSLAttributedResourceType *AttrResType =
-              dyn_cast<HLSLAttributedResourceType>(FD->getType().getTypePtr()))
-        return AttrResType;
+findHandleTypeOnResource(const Type *Ty) {
+  // If Ty is a resource class, the first field must
+  // be the resource handle of type HLSLAttributedResourceType
+  if (RecordDecl *RD = Ty->getAsCXXRecordDecl()) {
+    if (!RD->fields().empty()) {
+      const auto &FirstFD = RD->fields().begin();
+      return dyn_cast<HLSLAttributedResourceType>(
+          FirstFD->getType().getTypePtr());
     }
   }
   return nullptr;
 }
 
-// Iterate over RecordType fields and return true if any of them matched the
-// register type
-static bool ContainsResourceForRegisterType(Sema &S, const RecordType *RT,
-                                            RegisterType RegType) {
-  llvm::SmallVector<const Type *> TypesToScan;
-  TypesToScan.emplace_back(RT);
-
-  while (!TypesToScan.empty()) {
-    const Type *T = TypesToScan.pop_back_val();
-    while (T->isArrayType())
-      T = T->getArrayElementTypeNoTypeQual();
-    if (T->isIntegralOrEnumerationType() || T->isFloatingType()) {
-      if (RegType == RegisterType::C)
-        return true;
+// Walks through the global variable declaration, collects all resource binding
+// requirements and adds them to Bindings
+void SemaHLSL::collectResourcesOnUserRecordDecl(const VarDecl *VD,
+                                                const RecordType *RT) {
+  const RecordDecl *RD = RT->getDecl();
+  for (FieldDecl *FD : RD->fields()) {
+    const Type *Ty = FD->getType()->getUnqualifiedDesugaredType();
+
+    // Unwrap arrays
+    // FIXME: Calculate array size while unwrapping
+    assert(!Ty->isIncompleteArrayType() &&
+           "incomplete arrays inside user defined types are not supported");
+    while (Ty->isConstantArrayType()) {
+      const ConstantArrayType *CAT = cast<ConstantArrayType>(Ty);
+      Ty = CAT->getElementType()->getUnqualifiedDesugaredType();
     }
-    const RecordType *RT = T->getAs<RecordType>();
-    if (!RT)
+
+    if (!Ty->isRecordType())
       continue;
 
-    const RecordDecl *RD = RT->getDecl();
-    for (FieldDecl *FD : RD->fields()) {
-      const Type *FieldTy = FD->getType().getTypePtr();
-      if (const HLSLAttributedResourceType *AttrResType =
-              dyn_cast<HLSLAttributedResourceType>(FieldTy)) {
-        ResourceClass RC = AttrResType->getAttrs().ResourceClass;
-        if (getRegisterType(RC) == RegType)
-          return true;
-      } else {
-        TypesToScan.emplace_back(FD->getType().getTypePtr());
-      }
+    if (const HLSLAttributedResourceType *AttrResType =
+            findHandleTypeOnResource(Ty)) {
+      // Add a new DeclBindingInfo to Bindings if it does not already exist
+      ResourceClass RC = AttrResType->getAttrs().ResourceClass;
+      DeclBindingInfo *DBI = Bindings.getDeclBindingInfo(VD, RC);
+      if (!DBI)
+        Bindings.addDeclBindingInfo(VD, RC);
+    } else if (const RecordType *RT = dyn_cast<RecordType>(Ty)) {
+      // Recursively scan embedded struct or class; it would be nice to do this
+      // without recursion, but tricky to correctly calculate the size of the
+      // binding, which is something we are probably going to need to do later
+      // on. Hopefully nesting of structs in structs too many levels is
+      // unlikely.
+      collectResourcesOnUserRecordDecl(VD, RT);
     }
   }
-  return false;
 }
 
-static void CheckContainsResourceForRegisterType(Sema &S,
-                                                 SourceLocation &ArgLoc,
-                                                 Decl *D, RegisterType RegType,
-                                                 bool SpecifiedSpace) {
+// Diagnose localized register binding errors for a single binding; does not
+// diagnose resource binding on user record types, that will be done later
+// in processExplicitBindingsOnDecl based on the information collected in
+// collectResourcesOnVarDecl.
+// Returns false if the register binding is not valid.
+static bool DiagnoseLocalRegisterBinding(Sema &S, SourceLocation &ArgLoc,
+                                         Decl *D, RegisterType RegType,
+                                         bool SpecifiedSpace) {
   int RegTypeNum = static_cast<int>(RegType);
 
   // check if the decl type is groupshared
   if (D->hasAttr<HLSLGroupSharedAddressSpaceAttr>()) {
     S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
-    return;
+    return false;
   }
 
   // Cbuffers and Tbuffers are HLSLBufferDecl types
   if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D)) {
     ResourceClass RC = CBufferOrTBuffer->isCBuffer() ? ResourceClass::CBuffer
                                                      : ResourceClass::SRV;
-    if (RegType != getRegisterType(RC))
-      S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch)
-          << RegTypeNum;
-    return;
+    if (RegType == getRegisterType(RC))
+      return true;
+
+    S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch)
+        << RegTypeNum;
+    return false;
   }
 
   // Samplers, UAVs, and SRVs are VarDecl types
@@ -1075,11 +1134,13 @@ static void CheckContainsResourceForRegisterType(Sema &S,
 
   // Resource
   if (const HLSLAttributedResourceType *AttrResType =
-          findAttributedResourceTypeOnField(VD)) {
-    if (RegType != getRegisterType(AttrResType->getAttrs().ResourceClass))
-      S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch)
-          << RegTypeNum;
-    return;
+          findHandleTypeOnResource(VD->getType().getTypePtr())) {
+    if (RegType == getRegisterType(AttrResType->getAttrs().ResourceClass))
+      return true;
+
+    S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch)
+        << RegTypeNum;
+    return false;
   }
 
   const clang::Type *Ty = VD->getType().getTypePtr();
@@ -1105,51 +1166,44 @@ static void CheckContainsResourceForRegisterType(Sema &S,
       else
         S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
     }
-  } else if (Ty->isRecordType()) {
-    // Class/struct types - walk the declaration and check each field and
-    // subclass
-    if (!ContainsResourceForRegisterType(S, Ty->getAs<RecordType>(), RegType))
-      S.Diag(D->getLocation(), diag::warn_hlsl_user_defined_type_missing_member)
-          << RegTypeNum;
-  } else {
-    // Anything else is an error
-    S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
+    return false;
   }
+  if (Ty->isRecordType())
+    // RecordTypes will be diagnosed in processExplicitBindingsOnDecl
+    // that is called from ActOnVariableDeclarator
+    return true;
+
+  // Anything else is an error
+  S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
+  return false;
 }
 
-static void ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl,
+static bool ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl,
                                                 RegisterType regType) {
   // make sure that there are no two register annotations
   // applied to the decl with the same register type
   bool RegisterTypesDetected[5] = {false};
-
   RegisterTypesDetected[static_cast<int>(regType)] = true;
 
-  // we need a static map to keep track of previous conflicts
-  // so that we don't emit the same error multiple times
-  static std::map<Decl *, std::set<RegisterType>> PreviousConflicts;
-
   for (auto it = TheDecl->attr_begin(); it != TheDecl->attr_end(); ++it) {
     if (HLSLResourceBindingAttr *attr =
             dyn_cast<HLSLResourceBindingAttr>(*it)) {
 
-      RegisterType otherRegType = getRegisterType(attr->getSlot());
+      RegisterType otherRegType = attr->getRegisterType();
       if (RegisterTypesDetected[static_cast<int>(otherRegType)]) {
-        if (PreviousConflicts[TheDecl].count(otherRegType))
-          continue;
         int otherRegTypeNum = static_cast<int>(otherRegType);
         S.Diag(TheDecl->getLocation(),
                diag::err_hlsl_duplicate_register_annotation)
             << otherRegTypeNum;
-        PreviousConflicts[TheDecl].insert(otherRegType);
-      } else {
-        RegisterTypesDetected[static_cast<int>(otherRegType)] = true;
+        return false;
       }
+      RegisterTypesDetected[static_cast<int>(otherRegType)] = true;
     }
   }
+  return true;
 }
 
-static void DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc,
+static bool DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc,
                                           Decl *D, RegisterType RegType,
                                           bool SpecifiedSpace) {
 
@@ -1159,10 +1213,11 @@ static void DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc,
          "expecting VarDecl or HLSLBufferDecl");
 
   // check if the declaration contains resource matching the register type
-  CheckContainsResourceForRegisterType(S, ArgLoc, D, RegType, SpecifiedSpace);
+  if (!DiagnoseLocalRegisterBinding(S, ArgLoc, D, RegType, SpecifiedSpace))
+    return false;
 
   // next, if multiple register annotations exist, check that none conflict.
-  ValidateMultipleRegisterAnnotations(S, D, RegType);
+  return ValidateMultipleRegisterAnnotations(S, D, RegType);
 }
 
 void SemaHLSL::handleResourceBindingAttr(Decl *TheDecl, const ParsedAttr &AL) {
@@ -1203,23 +1258,23 @@ void SemaHLSL::handleResourceBindingAttr(Decl *TheDecl, const ParsedAttr &AL) {
     Slot = Str;
   }
 
-  RegisterType regType;
+  RegisterType RegType;
+  unsigned SlotNum = 0;
+  unsigned SpaceNum = 0;
 
   // Validate.
   if (!Slot.empty()) {
-    regType = getRegisterType(Slot);
-    if (regType == RegisterType::I) {
-      Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_i);
+    if (!convertToRegisterType(Slot, &RegType)) {
+      Diag(ArgLoc, diag::err_hlsl_binding_type_invalid) << Slot.substr(0, 1);
       return;
     }
-    if (regType == RegisterType::Invalid) {
-      Diag(ArgLoc, diag::err_hlsl_binding_type_invalid) << Slot.substr(0, 1);
+    if (RegType == RegisterType::I) {
+      Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_i);
       return;
     }
 
-    StringRef SlotNum = Slot.substr(1);
-    unsigned Num = 0;
-    if (SlotNum.getAsInteger(10, Num)) {
+    StringRef SlotNumStr = Slot.substr(1);
+    if (SlotNumStr.getAsInteger(10, SlotNum)) {
       Diag(ArgLoc, diag::err_hlsl_unsupported_register_number);
       return;
     }
@@ -1229,20 +1284,22 @@ void SemaHLSL::handleResourceBindingAttr(Decl *TheDecl, const ParsedAttr &AL) {
     Diag(SpaceArgLoc, diag::err_hlsl_expected_space) << Space;
     return;
   }
-  StringRef SpaceNum = Space.substr(5);
-  unsigned Num = 0;
-  if (SpaceNum.getAsInteger(10, Num)) {
+  StringRef SpaceNumStr = Space.substr(5);
+  if (SpaceNumStr.getAsInteger(10, SpaceNum)) {
     Diag(SpaceArgLoc, diag::err_hlsl_expected_space) << Space;
     return;
   }
 
-  DiagnoseHLSLRegisterAttribute(SemaRef, ArgLoc, TheDecl, regType,
-                                SpecifiedSpace);
+  if (!DiagnoseHLSLRegisterAttribute(SemaRef, ArgLoc, TheDecl, RegType,
+                                     SpecifiedSpace))
+    return;
 
   HLSLResourceBindingAttr *NewAttr =
       HLSLResourceBindingAttr::Create(getASTContext(), Slot, Space, AL);
-  if (NewAttr)
+  if (NewAttr) {
+    NewAttr->setBinding(RegType, SlotNum, SpaceNum);
     TheDecl->addAttr(NewAttr);
+  }
 }
 
 void SemaHLSL::handleParamModifierAttr(Decl *D, const ParsedAttr &AL) {
@@ -2089,6 +2146,7 @@ bool SemaHLSL::IsIntangibleType(clang::QualType QT) {
   CXXRecordDecl *RD = RT->getAsCXXRecordDecl();
   assert(RD != nullptr &&
          "all HLSL struct and classes should be CXXRecordDecl");
+  assert(RD->isCompleteDefinition() && "expecting complete type");
   return RD->isHLSLIntangible();
 }
 
@@ -2274,3 +2332,95 @@ QualType SemaHLSL::getInoutParameterType(QualType Ty) {
   Ty.addRestrict();
   return Ty;
 }
+
+void SemaHLSL::ActOnVariableDeclarator(VarDecl *VD) {
+  if (VD->hasGlobalStorage()) {
+    // make sure the declaration has a complete type
+    if (SemaRef.RequireCompleteType(
+            VD->getLocation(),
+            SemaRef.getASTContext().getBaseElementType(VD->getType()),
+            diag::err_typecheck_decl_incomplete_type)) {
+      VD->setInvalidDecl();
+      return;
+    }
+
+    // find all resources on decl
+    if (IsIntangibleType(VD->getType()))
+      collectResourcesOnVarDecl(VD);
+
+    // process explicit bindings
+    processExplicitBindingsOnDecl(VD);
+  }
+}
+
+// Walks through the global variable declaration, collects all resource binding
+// requirements and adds them to Bindings
+void SemaHLSL::collectResourcesOnVarDecl(VarDecl *VD) {
+  assert(VD->hasGlobalStorage() && IsIntangibleType(VD->getType()) &&
+         "expected global variable that contains HLSL resource");
+
+  // Cbuffers and Tbuffers are HLSLBufferDecl types
+  if (const HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(VD)) {
+    Bindings.addDeclBindingInfo(VD, CBufferOrTBuffer->isCBuffer()
+                                        ? ResourceClass::CBuffer
+                                        : ResourceClass::SRV);
+    return;
+  }
+
+  // Unwrap arrays
+  // FIXME: Calculate array size while unwrapping
+  const Type *Ty = VD->getType()->getUnqualifiedDesugaredType();
+  while (Ty->isConstantArrayType()) {
+    const ConstantArrayType *CAT = cast<ConstantArrayType>(Ty);
+    Ty = CAT->getElementType()->getUnqualifiedDesugaredType();
+  }
+
+  // Resource (or array of resources)
+  if (const HLSLAttributedResourceType *AttrResType =
+          findHandleTypeOnResource(Ty)) {
+    Bindings.addDeclBindingInfo(VD, AttrResType->getAttrs().ResourceClass);
+    return;
+  }
+
+  // User defined record type
+  if (const RecordType *RT = dyn_cast<RecordType>(Ty))
+    collectResourcesOnUserRecordDecl(VD, RT);
+}
+
+// Walks through the explicit resource binding attributes on the declaration,
+// and makes sure there is a resource that matches the binding and updates
+// the corresponding DeclBindingInfo in Bindings
+void SemaHLSL::processExplicitBindingsOnDecl(VarDecl *VD) {
+  assert(VD->hasGlobalStorage() && "expected global variable");
+
+  for (Attr *A : VD->attrs()) {
+    HLSLResourceBindingAttr *RBA = dyn_cast<HLSLResourceBindingAttr>(A);
+    if (!RBA)
+      continue;
+
+    RegisterType RT = RBA->getRegisterType();
+    assert(RT != RegisterType::I && "invalid or obsolete register type should "
+                                    "never have an attribute created");
+
+    if (RT == RegisterType::C) {
+      if (Bindings.hasBindingInfoForDecl(VD))
+        SemaRef.Diag(VD->getLocation(),
+                     diag::warn_hlsl_user_defined_type_missing_member)
+            << static_cast<int>(RT);
+      continue;
+    }
+
+    // Find DeclBindingInfo for this binding and update it, or report error
+    // if it does not exist (user type does not contain resources with the
+    // expected resource class).
+    ResourceClass RC = getResourceClass(RT);
+    if (DeclBindingInfo *BI = Bindings.getDeclBindingInfo(VD, RC)) {
+      // update binding info
+      BI->setBindingAttribute(RBA, BindingType::Explicit);
+    } else {
+      SemaRef.Diag(VD->getLocation(),
+                   diag::warn_hlsl_user_defined_type_missing_member)
+          << static_cast<int>(RT);
+    }
+  }
+}

diff --git a/clang/test/SemaHLSL/resource_binding_attr_error_udt.hlsl b/clang/test/SemaHLSL/resource_binding_attr_error_udt.hlsl
index ea2d576e4cca55..40517f393e1284 100644
--- a/clang/test/SemaHLSL/resource_binding_attr_error_udt.hlsl
+++ b/clang/test/SemaHLSL/resource_binding_attr_error_udt.hlsl
@@ -106,7 +106,6 @@ struct Eg12{
   MySRV s1;
   MySRV s2;
 };
-// expected-warning@+3{{binding type 'u' only applies to types containing UAV resources}}
 // expected-warning@+2{{binding type 'u' only applies to types containing UAV resources}}
 // expected-error@+1{{binding type 'u' cannot be applied more than once}}
 Eg12 e12 : register(u9) : register(u10);
@@ -115,12 +114,14 @@ struct Eg13{
   MySRV s1;
   MySRV s2;
 };
-// expected-warning@+4{{binding type 'u' only applies to types containing UAV resources}}
 // expected-warning@+3{{binding type 'u' only applies to types containing UAV resources}}
-// expected-warning@+2{{binding type 'u' only applies to types containing UAV resources}}
+// expected-error@+2{{binding type 'u' cannot be applied more than once}}
 // expected-error@+1{{binding type 'u' cannot be applied more than once}}
 Eg13 e13 : register(u9) : register(u10) : register(u11);
 
+// expected-error@+1{{binding type 't' cannot be applied more than once}}
+Eg13 e13_2 : register(t11) : register(t12);
+
 struct Eg14{
  MyTemplatedUAV<int> r1;  
 };
@@ -132,4 +133,3 @@ struct Eg15 {
 }; 
 // expected no error
 Eg15 e15 : register(c0);
-


        


More information about the cfe-commits mailing list