[clang] [HLSL][NFC] Remove RegisterBindingFlags struct (PR #108924)

Helena Kotas via cfe-commits cfe-commits at lists.llvm.org
Tue Sep 17 22:12:04 PDT 2024


https://github.com/hekota updated https://github.com/llvm/llvm-project/pull/108924

>From 1dd552dfb6217804ba5e84a35e59e348622df581 Mon Sep 17 00:00:00 2001
From: Helena Kotas <hekotas at microsoft.com>
Date: Mon, 16 Sep 2024 20:54:23 -0700
Subject: [PATCH 1/5] [HLSL][NFC] Remove RegisterBindingFlags struct

When diagnosing register bindings we just need to make sure there is
a resource that matches the provided register type. We can emit the
diagnostics right away instead of collecting flags in the
RegisterBindingFlags struct. That also enables early exit when scanning
user-defined types because we can return as soon as we find a matching
resource for the given register type.
---
 clang/lib/Sema/SemaHLSL.cpp | 283 +++++++++++++-----------------------
 1 file changed, 101 insertions(+), 182 deletions(-)

diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 26de9a986257c5..9cc1860e52bd2c 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -41,6 +41,47 @@
 
 using namespace clang;
 
+enum class RegisterType { SRV, UAV, CBuffer, Sampler, C, I, Invalid };
+
+static RegisterType getRegisterType(llvm::dxil::ResourceClass RC) {
+  switch (RC) {
+  case llvm::dxil::ResourceClass::SRV:
+    return RegisterType::SRV;
+  case llvm::dxil::ResourceClass::UAV:
+    return RegisterType::UAV;
+  case llvm::dxil::ResourceClass::CBuffer:
+    return RegisterType::CBuffer;
+  case llvm::dxil::ResourceClass::Sampler:
+    return RegisterType::Sampler;
+  }
+  llvm_unreachable("unexpected ResourceClass value");
+}
+
+static RegisterType getRegisterType(StringRef Slot) {
+  switch (Slot[0]) {
+  case 't':
+  case 'T':
+    return RegisterType::SRV;
+  case 'u':
+  case 'U':
+    return RegisterType::UAV;
+  case 'b':
+  case 'B':
+    return RegisterType::CBuffer;
+  case 's':
+  case 'S':
+    return RegisterType::Sampler;
+  case 'c':
+  case 'C':
+    return RegisterType::C;
+  case 'i':
+  case 'I':
+    return RegisterType::I;
+  default:
+    return RegisterType::Invalid;
+  }
+}
+
 SemaHLSL::SemaHLSL(Sema &S) : SemaBase(S) {}
 
 Decl *SemaHLSL::ActOnStartBuffer(Scope *BufferScope, bool CBuffer,
@@ -739,28 +780,6 @@ SemaHLSL::TakeLocForHLSLAttribute(const HLSLAttributedResourceType *RT) {
   return LocInfo;
 }
 
-struct RegisterBindingFlags {
-  bool Resource = false;
-  bool UDT = false;
-  bool Other = false;
-  bool Basic = false;
-
-  bool SRV = false;
-  bool UAV = false;
-  bool CBV = false;
-  bool Sampler = false;
-
-  bool ContainsNumeric = false;
-  bool DefaultGlobals = false;
-
-  // used only when Resource == true
-  std::optional<llvm::dxil::ResourceClass> ResourceClass;
-};
-
-static bool isDeclaredWithinCOrTBuffer(const Decl *TheDecl) {
-  return TheDecl && isa<HLSLBufferDecl>(TheDecl->getDeclContext());
-}
-
 // get the record decl from a var decl that we expect
 // represents a resource
 static CXXRecordDecl *getRecordDeclFromVarDecl(VarDecl *VD) {
@@ -775,24 +794,6 @@ static CXXRecordDecl *getRecordDeclFromVarDecl(VarDecl *VD) {
   return TheRecordDecl;
 }
 
-static void updateResourceClassFlagsFromDeclResourceClass(
-    RegisterBindingFlags &Flags, llvm::hlsl::ResourceClass DeclResourceClass) {
-  switch (DeclResourceClass) {
-  case llvm::hlsl::ResourceClass::SRV:
-    Flags.SRV = true;
-    break;
-  case llvm::hlsl::ResourceClass::UAV:
-    Flags.UAV = true;
-    break;
-  case llvm::hlsl::ResourceClass::CBuffer:
-    Flags.CBV = true;
-    break;
-  case llvm::hlsl::ResourceClass::Sampler:
-    Flags.Sampler = true;
-    break;
-  }
-}
-
 const HLSLAttributedResourceType *
 findAttributedResourceTypeOnField(VarDecl *VD) {
   assert(VD != nullptr && "expected VarDecl");
@@ -806,8 +807,8 @@ findAttributedResourceTypeOnField(VarDecl *VD) {
   return nullptr;
 }
 
-static void updateResourceClassFlagsFromRecordType(RegisterBindingFlags &Flags,
-                                                   const RecordType *RT) {
+// Iterate over RecordType fields and return true if any of them matched the register type
+static bool ContainsResourceForRegisterType(Sema &S, const RecordType *RT, RegisterType RegType) {
   llvm::SmallVector<const Type *> TypesToScan;
   TypesToScan.emplace_back(RT);
 
@@ -816,8 +817,8 @@ static void updateResourceClassFlagsFromRecordType(RegisterBindingFlags &Flags,
     while (T->isArrayType())
       T = T->getArrayElementTypeNoTypeQual();
     if (T->isIntegralOrEnumerationType() || T->isFloatingType()) {
-      Flags.ContainsNumeric = true;
-      continue;
+      if (RegType == RegisterType::C)
+        return true;
     }
     const RecordType *RT = T->getAs<RecordType>();
     if (!RT)
@@ -828,101 +829,74 @@ static void updateResourceClassFlagsFromRecordType(RegisterBindingFlags &Flags,
       const Type *FieldTy = FD->getType().getTypePtr();
       if (const HLSLAttributedResourceType *AttrResType =
               dyn_cast<HLSLAttributedResourceType>(FieldTy)) {
-        updateResourceClassFlagsFromDeclResourceClass(
-            Flags, AttrResType->getAttrs().ResourceClass);
+        llvm::dxil::ResourceClass RC = AttrResType->getAttrs().ResourceClass;
+        if (getRegisterType(RC) == RegType)
+          return true;
         continue;
       }
       TypesToScan.emplace_back(FD->getType().getTypePtr());
     }
   }
+  return false;
 }
 
-static RegisterBindingFlags HLSLFillRegisterBindingFlags(Sema &S,
-                                                         Decl *TheDecl) {
-  RegisterBindingFlags Flags;
+static void CheckContainsResourceForRegisterType(Sema &S, SourceLocation &ArgLoc, Decl *D, RegisterType RegType) {
+  int RegTypeNum = static_cast<int>(RegType);
 
   // check if the decl type is groupshared
-  if (TheDecl->hasAttr<HLSLGroupSharedAddressSpaceAttr>()) {
-    Flags.Other = true;
-    return Flags;
+  if (D->hasAttr<HLSLGroupSharedAddressSpaceAttr>()) {
+    S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
+    return;
   }
 
   // Cbuffers and Tbuffers are HLSLBufferDecl types
-  if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(TheDecl)) {
-    Flags.Resource = true;
-    Flags.ResourceClass = CBufferOrTBuffer->isCBuffer()
-                              ? llvm::dxil::ResourceClass::CBuffer
-                              : llvm::dxil::ResourceClass::SRV;
+  if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D)) {
+    llvm::dxil::ResourceClass RC = CBufferOrTBuffer->isCBuffer() ? llvm::dxil::ResourceClass::CBuffer : llvm::dxil::ResourceClass::SRV;
+    if (RegType != getRegisterType(RC))
+      S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
+    return;
   }
   // Samplers, UAVs, and SRVs are VarDecl types
-  else if (VarDecl *TheVarDecl = dyn_cast<VarDecl>(TheDecl)) {
+  if (VarDecl *TheVarDecl = dyn_cast<VarDecl>(D)) {
+    // Resource
     if (const HLSLAttributedResourceType *AttrResType =
             findAttributedResourceTypeOnField(TheVarDecl)) {
-      Flags.Resource = true;
-      Flags.ResourceClass = AttrResType->getAttrs().ResourceClass;
+      if (RegType != getRegisterType(AttrResType->getAttrs().ResourceClass))
+        S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
+      return;
+    } 
+    
+    const clang::Type *TheBaseType = TheVarDecl->getType().getTypePtr();
+    while (TheBaseType->isArrayType())
+      TheBaseType = TheBaseType->getArrayElementTypeNoTypeQual();
+
+    // Basic types
+    if (TheBaseType->isArithmeticType()) {
+      if (!isa<HLSLBufferDecl>(D->getDeclContext()) &&
+          (TheBaseType->isIntegralType(S.getASTContext()) ||
+            TheBaseType->isFloatingType())) {
+        // Default Globals
+        if (RegType == RegisterType::CBuffer)
+          S.Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_b);
+        else if (RegType != RegisterType::C)
+          S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
+      } else {
+        if (RegType == RegisterType::C)
+          S.Diag(ArgLoc, diag::warn_hlsl_register_type_c_packoffset);
+        else
+          S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
+      }
+    } else if (TheBaseType->isRecordType()) {
+      // Class/struct types - walk the declaration and check each field and subclass
+      if (!ContainsResourceForRegisterType(S, TheBaseType->getAs<RecordType>(), RegType))
+            S.Diag(D->getLocation(), diag::warn_hlsl_user_defined_type_missing_member) << RegTypeNum;
     } else {
-      const clang::Type *TheBaseType = TheVarDecl->getType().getTypePtr();
-      while (TheBaseType->isArrayType())
-        TheBaseType = TheBaseType->getArrayElementTypeNoTypeQual();
-
-      if (TheBaseType->isArithmeticType()) {
-        Flags.Basic = true;
-        if (!isDeclaredWithinCOrTBuffer(TheDecl) &&
-            (TheBaseType->isIntegralType(S.getASTContext()) ||
-             TheBaseType->isFloatingType()))
-          Flags.DefaultGlobals = true;
-      } else if (TheBaseType->isRecordType()) {
-        Flags.UDT = true;
-        const RecordType *TheRecordTy = TheBaseType->getAs<RecordType>();
-        updateResourceClassFlagsFromRecordType(Flags, TheRecordTy);
-      } else
-        Flags.Other = true;
+      // Anything else is an error
+      S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
     }
-  } else {
-    llvm_unreachable("expected be VarDecl or HLSLBufferDecl");
-  }
-  return Flags;
-}
-
-enum class RegisterType { SRV, UAV, CBuffer, Sampler, C, I, Invalid };
-
-static RegisterType getRegisterType(llvm::dxil::ResourceClass RC) {
-  switch (RC) {
-  case llvm::dxil::ResourceClass::SRV:
-    return RegisterType::SRV;
-  case llvm::dxil::ResourceClass::UAV:
-    return RegisterType::UAV;
-  case llvm::dxil::ResourceClass::CBuffer:
-    return RegisterType::CBuffer;
-  case llvm::dxil::ResourceClass::Sampler:
-    return RegisterType::Sampler;
-  }
-  llvm_unreachable("unexpected ResourceClass value");
-}
-
-static RegisterType getRegisterType(StringRef Slot) {
-  switch (Slot[0]) {
-  case 't':
-  case 'T':
-    return RegisterType::SRV;
-  case 'u':
-  case 'U':
-    return RegisterType::UAV;
-  case 'b':
-  case 'B':
-    return RegisterType::CBuffer;
-  case 's':
-  case 'S':
-    return RegisterType::Sampler;
-  case 'c':
-  case 'C':
-    return RegisterType::C;
-  case 'i':
-  case 'I':
-    return RegisterType::I;
-  default:
-    return RegisterType::Invalid;
+    return;
   }
+  llvm_unreachable("expected be VarDecl or HLSLBufferDecl");
 }
 
 static void ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl,
@@ -958,73 +932,18 @@ static void ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl,
 }
 
 static void DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc,
-                                          Decl *TheDecl, RegisterType regType) {
+                                          Decl *D, RegisterType RegType) {
 
   // exactly one of these two types should be set
-  assert(((isa<VarDecl>(TheDecl) && !isa<HLSLBufferDecl>(TheDecl)) ||
-          (!isa<VarDecl>(TheDecl) && isa<HLSLBufferDecl>(TheDecl))) &&
+  assert(((isa<VarDecl>(D) && !isa<HLSLBufferDecl>(D)) ||
+          (!isa<VarDecl>(D) && isa<HLSLBufferDecl>(D))) &&
          "expecting VarDecl or HLSLBufferDecl");
 
-  RegisterBindingFlags Flags = HLSLFillRegisterBindingFlags(S, TheDecl);
-  assert((int)Flags.Other + (int)Flags.Resource + (int)Flags.Basic +
-                 (int)Flags.UDT ==
-             1 &&
-         "only one resource analysis result should be expected");
-
-  int regTypeNum = static_cast<int>(regType);
-
-  // first, if "other" is set, emit an error
-  if (Flags.Other) {
-    S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << regTypeNum;
-    return;
-  }
+  // check if the declaration contains resource matching the register type
+  CheckContainsResourceForRegisterType(S, ArgLoc, D, RegType);
 
-  // next, if multiple register annotations exist, check that none conflict.
-  ValidateMultipleRegisterAnnotations(S, TheDecl, regType);
-
-  // next, if resource is set, make sure the register type in the register
-  // annotation is compatible with the variable's resource type.
-  if (Flags.Resource) {
-    RegisterType expRegType = getRegisterType(Flags.ResourceClass.value());
-    if (regType != expRegType) {
-      S.Diag(TheDecl->getLocation(), diag::err_hlsl_binding_type_mismatch)
-          << regTypeNum;
-    }
-    return;
-  }
-
-  // next, handle diagnostics for when the "basic" flag is set
-  if (Flags.Basic) {
-    if (Flags.DefaultGlobals) {
-      if (regType == RegisterType::CBuffer)
-        S.Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_b);
-      else if (regType != RegisterType::C)
-        S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << regTypeNum;
-      return;
-    }
-
-    if (regType == RegisterType::C)
-      S.Diag(ArgLoc, diag::warn_hlsl_register_type_c_packoffset);
-    else
-      S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << regTypeNum;
-
-    return;
-  }
-
-  // finally, we handle the udt case
-  if (Flags.UDT) {
-    const bool ExpectedRegisterTypesForUDT[] = {
-        Flags.SRV, Flags.UAV, Flags.CBV, Flags.Sampler, Flags.ContainsNumeric};
-    assert((size_t)regTypeNum < std::size(ExpectedRegisterTypesForUDT) &&
-           "regType has unexpected value");
-
-    if (!ExpectedRegisterTypesForUDT[regTypeNum])
-      S.Diag(TheDecl->getLocation(),
-             diag::warn_hlsl_user_defined_type_missing_member)
-          << regTypeNum;
-
-    return;
-  }
+  // check multiple register annotations
+  ValidateMultipleRegisterAnnotations(S, D, RegType);
 }
 
 void SemaHLSL::handleResourceBindingAttr(Decl *TheDecl, const ParsedAttr &AL) {

>From 993f1d066e3f114ea0bd028e8f1840ff1d07fa26 Mon Sep 17 00:00:00 2001
From: Helena Kotas <hekotas at microsoft.com>
Date: Mon, 16 Sep 2024 20:57:15 -0700
Subject: [PATCH 2/5] revert comment change

---
 clang/lib/Sema/SemaHLSL.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 9cc1860e52bd2c..8883652947d504 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -942,7 +942,7 @@ static void DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc,
   // check if the declaration contains resource matching the register type
   CheckContainsResourceForRegisterType(S, ArgLoc, D, RegType);
 
-  // check multiple register annotations
+  // next, if multiple register annotations exist, check that none conflict.
   ValidateMultipleRegisterAnnotations(S, D, RegType);
 }
 

>From dfd351079fd369472ddf63d81252c7152be89b2a Mon Sep 17 00:00:00 2001
From: Helena Kotas <hekotas at microsoft.com>
Date: Mon, 16 Sep 2024 21:02:49 -0700
Subject: [PATCH 3/5] clang-format

---
 clang/lib/Sema/SemaHLSL.cpp | 37 +++++++++++++++++++++++++------------
 1 file changed, 25 insertions(+), 12 deletions(-)

diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 8883652947d504..45a8d75ce099cc 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -807,8 +807,10 @@ findAttributedResourceTypeOnField(VarDecl *VD) {
   return nullptr;
 }
 
-// Iterate over RecordType fields and return true if any of them matched the register type
-static bool ContainsResourceForRegisterType(Sema &S, const RecordType *RT, RegisterType RegType) {
+// Iterate over RecordType fields and return true if any of them matched the
+// register type
+static bool ContainsResourceForRegisterType(Sema &S, const RecordType *RT,
+                                            RegisterType RegType) {
   llvm::SmallVector<const Type *> TypesToScan;
   TypesToScan.emplace_back(RT);
 
@@ -840,7 +842,10 @@ static bool ContainsResourceForRegisterType(Sema &S, const RecordType *RT, Regis
   return false;
 }
 
-static void CheckContainsResourceForRegisterType(Sema &S, SourceLocation &ArgLoc, Decl *D, RegisterType RegType) {
+static void CheckContainsResourceForRegisterType(Sema &S,
+                                                 SourceLocation &ArgLoc,
+                                                 Decl *D,
+                                                 RegisterType RegType) {
   int RegTypeNum = static_cast<int>(RegType);
 
   // check if the decl type is groupshared
@@ -851,9 +856,12 @@ static void CheckContainsResourceForRegisterType(Sema &S, SourceLocation &ArgLoc
 
   // Cbuffers and Tbuffers are HLSLBufferDecl types
   if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D)) {
-    llvm::dxil::ResourceClass RC = CBufferOrTBuffer->isCBuffer() ? llvm::dxil::ResourceClass::CBuffer : llvm::dxil::ResourceClass::SRV;
+    llvm::dxil::ResourceClass RC = CBufferOrTBuffer->isCBuffer()
+                                       ? llvm::dxil::ResourceClass::CBuffer
+                                       : llvm::dxil::ResourceClass::SRV;
     if (RegType != getRegisterType(RC))
-      S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
+      S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch)
+          << RegTypeNum;
     return;
   }
   // Samplers, UAVs, and SRVs are VarDecl types
@@ -862,10 +870,11 @@ static void CheckContainsResourceForRegisterType(Sema &S, SourceLocation &ArgLoc
     if (const HLSLAttributedResourceType *AttrResType =
             findAttributedResourceTypeOnField(TheVarDecl)) {
       if (RegType != getRegisterType(AttrResType->getAttrs().ResourceClass))
-        S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
+        S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch)
+            << RegTypeNum;
       return;
-    } 
-    
+    }
+
     const clang::Type *TheBaseType = TheVarDecl->getType().getTypePtr();
     while (TheBaseType->isArrayType())
       TheBaseType = TheBaseType->getArrayElementTypeNoTypeQual();
@@ -874,7 +883,7 @@ static void CheckContainsResourceForRegisterType(Sema &S, SourceLocation &ArgLoc
     if (TheBaseType->isArithmeticType()) {
       if (!isa<HLSLBufferDecl>(D->getDeclContext()) &&
           (TheBaseType->isIntegralType(S.getASTContext()) ||
-            TheBaseType->isFloatingType())) {
+           TheBaseType->isFloatingType())) {
         // Default Globals
         if (RegType == RegisterType::CBuffer)
           S.Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_b);
@@ -887,9 +896,13 @@ static void CheckContainsResourceForRegisterType(Sema &S, SourceLocation &ArgLoc
           S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
       }
     } else if (TheBaseType->isRecordType()) {
-      // Class/struct types - walk the declaration and check each field and subclass
-      if (!ContainsResourceForRegisterType(S, TheBaseType->getAs<RecordType>(), RegType))
-            S.Diag(D->getLocation(), diag::warn_hlsl_user_defined_type_missing_member) << RegTypeNum;
+      // Class/struct types - walk the declaration and check each field and
+      // subclass
+      if (!ContainsResourceForRegisterType(S, TheBaseType->getAs<RecordType>(),
+                                           RegType))
+        S.Diag(D->getLocation(),
+               diag::warn_hlsl_user_defined_type_missing_member)
+            << RegTypeNum;
     } else {
       // Anything else is an error
       S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;

>From 69fe0a05a8b134853e2b5daa716c20636d1e74ca Mon Sep 17 00:00:00 2001
From: Helena Kotas <hekotas at microsoft.com>
Date: Tue, 17 Sep 2024 10:52:40 -0700
Subject: [PATCH 4/5] Code review feedback

+ a couple of local variable names shortened for easier debugging when manually invoking dump() in the debugger
---
 clang/lib/Sema/SemaHLSL.cpp | 106 +++++++++++++++++-------------------
 1 file changed, 51 insertions(+), 55 deletions(-)

diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 45a8d75ce099cc..67fb86d2cec0bd 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -40,18 +40,19 @@
 #include <utility>
 
 using namespace clang;
+using llvm::dxil::ResourceClass;
 
 enum class RegisterType { SRV, UAV, CBuffer, Sampler, C, I, Invalid };
 
-static RegisterType getRegisterType(llvm::dxil::ResourceClass RC) {
+static RegisterType getRegisterType(ResourceClass RC) {
   switch (RC) {
-  case llvm::dxil::ResourceClass::SRV:
+  case ResourceClass::SRV:
     return RegisterType::SRV;
-  case llvm::dxil::ResourceClass::UAV:
+  case ResourceClass::UAV:
     return RegisterType::UAV;
-  case llvm::dxil::ResourceClass::CBuffer:
+  case ResourceClass::CBuffer:
     return RegisterType::CBuffer;
-  case llvm::dxil::ResourceClass::Sampler:
+  case ResourceClass::Sampler:
     return RegisterType::Sampler;
   }
   llvm_unreachable("unexpected ResourceClass value");
@@ -627,8 +628,7 @@ bool clang::CreateHLSLAttributedResourceType(
     LocEnd = A->getRange().getEnd();
     switch (A->getKind()) {
     case attr::HLSLResourceClass: {
-      llvm::dxil::ResourceClass RC =
-          cast<HLSLResourceClassAttr>(A)->getResourceClass();
+      ResourceClass RC = cast<HLSLResourceClassAttr>(A)->getResourceClass();
       if (HasResourceClass) {
         S.Diag(A->getLocation(), ResAttrs.ResourceClass == RC
                                      ? diag::warn_duplicate_attribute_exact
@@ -706,7 +706,7 @@ bool SemaHLSL::handleResourceTypeAttr(const ParsedAttr &AL) {
     SourceLocation ArgLoc = Loc->Loc;
 
     // Validate resource class value
-    llvm::dxil::ResourceClass RC;
+    ResourceClass RC;
     if (!HLSLResourceClassAttr::ConvertStrToResourceClass(Identifier, RC)) {
       Diag(ArgLoc, diag::warn_attribute_type_not_supported)
           << "ResourceClass" << Identifier;
@@ -831,12 +831,12 @@ static bool ContainsResourceForRegisterType(Sema &S, const RecordType *RT,
       const Type *FieldTy = FD->getType().getTypePtr();
       if (const HLSLAttributedResourceType *AttrResType =
               dyn_cast<HLSLAttributedResourceType>(FieldTy)) {
-        llvm::dxil::ResourceClass RC = AttrResType->getAttrs().ResourceClass;
+        ResourceClass RC = AttrResType->getAttrs().ResourceClass;
         if (getRegisterType(RC) == RegType)
           return true;
-        continue;
+      } else {
+        TypesToScan.emplace_back(FD->getType().getTypePtr());
       }
-      TypesToScan.emplace_back(FD->getType().getTypePtr());
     }
   }
   return false;
@@ -856,60 +856,56 @@ static void CheckContainsResourceForRegisterType(Sema &S,
 
   // Cbuffers and Tbuffers are HLSLBufferDecl types
   if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D)) {
-    llvm::dxil::ResourceClass RC = CBufferOrTBuffer->isCBuffer()
-                                       ? llvm::dxil::ResourceClass::CBuffer
-                                       : llvm::dxil::ResourceClass::SRV;
+    ResourceClass RC = CBufferOrTBuffer->isCBuffer() ? ResourceClass::CBuffer
+                                                     : ResourceClass::SRV;
     if (RegType != getRegisterType(RC))
       S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch)
           << RegTypeNum;
     return;
   }
+
   // Samplers, UAVs, and SRVs are VarDecl types
-  if (VarDecl *TheVarDecl = dyn_cast<VarDecl>(D)) {
-    // Resource
-    if (const HLSLAttributedResourceType *AttrResType =
-            findAttributedResourceTypeOnField(TheVarDecl)) {
-      if (RegType != getRegisterType(AttrResType->getAttrs().ResourceClass))
-        S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch)
-            << RegTypeNum;
-      return;
-    }
+  assert(isa<VarDecl>(D) && "D is expected to be VarDecl or HLSLBufferDecl");
+  VarDecl *VD = cast<VarDecl>(D);
 
-    const clang::Type *TheBaseType = TheVarDecl->getType().getTypePtr();
-    while (TheBaseType->isArrayType())
-      TheBaseType = TheBaseType->getArrayElementTypeNoTypeQual();
-
-    // Basic types
-    if (TheBaseType->isArithmeticType()) {
-      if (!isa<HLSLBufferDecl>(D->getDeclContext()) &&
-          (TheBaseType->isIntegralType(S.getASTContext()) ||
-           TheBaseType->isFloatingType())) {
-        // Default Globals
-        if (RegType == RegisterType::CBuffer)
-          S.Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_b);
-        else if (RegType != RegisterType::C)
-          S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
-      } else {
-        if (RegType == RegisterType::C)
-          S.Diag(ArgLoc, diag::warn_hlsl_register_type_c_packoffset);
-        else
-          S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
-      }
-    } else if (TheBaseType->isRecordType()) {
-      // Class/struct types - walk the declaration and check each field and
-      // subclass
-      if (!ContainsResourceForRegisterType(S, TheBaseType->getAs<RecordType>(),
-                                           RegType))
-        S.Diag(D->getLocation(),
-               diag::warn_hlsl_user_defined_type_missing_member)
-            << RegTypeNum;
+  // Resource
+  if (const HLSLAttributedResourceType *AttrResType =
+          findAttributedResourceTypeOnField(VD)) {
+    if (RegType != getRegisterType(AttrResType->getAttrs().ResourceClass))
+      S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch)
+          << RegTypeNum;
+    return;
+  }
+
+  const clang::Type *Ty = VD->getType().getTypePtr();
+  while (Ty->isArrayType())
+    Ty = Ty->getArrayElementTypeNoTypeQual();
+
+  // Basic types
+  if (Ty->isArithmeticType()) {
+    if (!isa<HLSLBufferDecl>(D->getDeclContext()) &&
+        (Ty->isIntegralType(S.getASTContext()) || Ty->isFloatingType())) {
+      // Default Globals
+      if (RegType == RegisterType::CBuffer)
+        S.Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_b);
+      else if (RegType != RegisterType::C)
+        S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
     } else {
-      // Anything else is an error
-      S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
+      if (RegType == RegisterType::C)
+        S.Diag(ArgLoc, diag::warn_hlsl_register_type_c_packoffset);
+      else
+        S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
     }
-    return;
+  } else if (Ty->isRecordType()) {
+    // Class/struct types - walk the declaration and check each field and
+    // subclass
+    if (!ContainsResourceForRegisterType(S, Ty->getAs<RecordType>(), RegType))
+      S.Diag(D->getLocation(), diag::warn_hlsl_user_defined_type_missing_member)
+          << RegTypeNum;
+  } else {
+    // Anything else is an error
+    S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
   }
-  llvm_unreachable("expected be VarDecl or HLSLBufferDecl");
 }
 
 static void ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl,

>From 12e8ea6b15378b32ca1474e0da6f4ff3fb003a85 Mon Sep 17 00:00:00 2001
From: Helena Kotas <hekotas at microsoft.com>
Date: Tue, 17 Sep 2024 17:05:00 -0700
Subject: [PATCH 5/5] clang-format

---
 clang/lib/Sema/SemaHLSL.cpp | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 6a48121cea2d42..03b7c2edb605fe 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -855,8 +855,8 @@ static bool ContainsResourceForRegisterType(Sema &S, const RecordType *RT,
 
 static void CheckContainsResourceForRegisterType(Sema &S,
                                                  SourceLocation &ArgLoc,
-                                                 Decl *D,
-                                                 RegisterType RegType, bool SpecifiedSpace) {
+                                                 Decl *D, RegisterType RegType,
+                                                 bool SpecifiedSpace) {
   int RegTypeNum = static_cast<int>(RegType);
 
   // check if the decl type is groupshared
@@ -898,7 +898,8 @@ static void CheckContainsResourceForRegisterType(Sema &S,
     if (SpecifiedSpace && !DeclaredInCOrTBuffer)
       S.Diag(ArgLoc, diag::err_hlsl_space_on_global_constant);
 
-    if (!DeclaredInCOrTBuffer && (Ty->isIntegralType(S.getASTContext()) || Ty->isFloatingType())) {
+    if (!DeclaredInCOrTBuffer &&
+        (Ty->isIntegralType(S.getASTContext()) || Ty->isFloatingType())) {
       // Default Globals
       if (RegType == RegisterType::CBuffer)
         S.Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_b);
@@ -955,7 +956,8 @@ static void ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl,
 }
 
 static void DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc,
-                                          Decl *D, RegisterType RegType, bool SpecifiedSpace) {
+                                          Decl *D, RegisterType RegType,
+                                          bool SpecifiedSpace) {
 
   // exactly one of these two types should be set
   assert(((isa<VarDecl>(D) && !isa<HLSLBufferDecl>(D)) ||



More information about the cfe-commits mailing list