[clang] [HLSL] Collect explicit resource binding information (PR #111203)
Helena Kotas via cfe-commits
cfe-commits at lists.llvm.org
Tue Oct 15 21:51:51 PDT 2024
https://github.com/hekota updated https://github.com/llvm/llvm-project/pull/111203
>From f545a14e11556c91d10b14617e3588fe5eae6d42 Mon Sep 17 00:00:00 2001
From: Helena Kotas <hekotas at microsoft.com>
Date: Fri, 4 Oct 2024 12:21:51 -0700
Subject: [PATCH 1/5] [HLSL] Collect explicit resource binding information
(part 1)
- Do not create resource binding attribute if it is not valid
- Store basic resource binding information on HLSLResourceBindingAttr
- Move UDT type checking to ActOnVariableDeclarator
Part 1 of #110719
---
clang/include/clang/Basic/Attr.td | 29 +++
clang/include/clang/Sema/SemaHLSL.h | 2 +
clang/lib/Sema/SemaDecl.cpp | 3 +
clang/lib/Sema/SemaHLSL.cpp | 227 ++++++++++++------
.../resource_binding_attr_error_udt.hlsl | 8 +-
5 files changed, 188 insertions(+), 81 deletions(-)
diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index fbcbf0ed416416..668c599da81390 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -4588,6 +4588,35 @@ def HLSLResourceBinding: InheritableAttr {
let LangOpts = [HLSL];
let Args = [StringArgument<"Slot">, StringArgument<"Space", 1>];
let Documentation = [HLSLResourceBindingDocs];
+ let AdditionalMembers = [{
+ enum class RegisterType : unsigned { SRV, UAV, CBuffer, Sampler, C, I, Invalid };
+
+ const FieldDecl *ResourceField = nullptr;
+ RegisterType RegType;
+ unsigned SlotNumber;
+ unsigned SpaceNumber;
+
+ void setBinding(RegisterType RT, unsigned SlotNum, unsigned SpaceNum) {
+ RegType = RT;
+ SlotNumber = SlotNum;
+ SpaceNumber = SpaceNum;
+ }
+ void setResourceField(const FieldDecl *FD) {
+ ResourceField = FD;
+ }
+ const FieldDecl *getResourceField() {
+ return ResourceField;
+ }
+ RegisterType getRegisterType() {
+ return RegType;
+ }
+ unsigned getSlotNumber() {
+ return SlotNumber;
+ }
+ unsigned getSpaceNumber() {
+ return SpaceNumber;
+ }
+ }];
}
def HLSLPackOffset: HLSLAnnotationAttr {
diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h
index fa957abc9791af..018e7ea5901a2b 100644
--- a/clang/include/clang/Sema/SemaHLSL.h
+++ b/clang/include/clang/Sema/SemaHLSL.h
@@ -28,6 +28,7 @@ class AttributeCommonInfo;
class IdentifierInfo;
class ParsedAttr;
class Scope;
+class VarDecl;
// FIXME: This can be hidden (as static function in SemaHLSL.cpp) once we no
// longer need to create builtin buffer types in HLSLExternalSemaSource.
@@ -62,6 +63,7 @@ class SemaHLSL : public SemaBase {
const Attr *A, llvm::Triple::EnvironmentType Stage,
std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages);
void DiagnoseAvailabilityViolations(TranslationUnitDecl *TU);
+ void ProcessResourceBindingOnDecl(VarDecl *D);
QualType handleVectorBinOpConversion(ExprResult &LHS, ExprResult &RHS,
QualType LHSType, QualType RHSType,
diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp
index 2bf610746bc317..8e27a5e068e702 100644
--- a/clang/lib/Sema/SemaDecl.cpp
+++ b/clang/lib/Sema/SemaDecl.cpp
@@ -7876,6 +7876,9 @@ NamedDecl *Sema::ActOnVariableDeclarator(
// Handle attributes prior to checking for duplicates in MergeVarDecl
ProcessDeclAttributes(S, NewVD, D);
+ if (getLangOpts().HLSL)
+ HLSL().ProcessResourceBindingOnDecl(NewVD);
+
// FIXME: This is probably the wrong location to be doing this and we should
// probably be doing this for more attributes (especially for function
// pointer attributes such as format, warn_unused_result, etc.). Ideally
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index fbcba201a351a6..568a8de30c1fc5 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -41,9 +41,7 @@
using namespace clang;
using llvm::dxil::ResourceClass;
-
-enum class RegisterType { SRV, UAV, CBuffer, Sampler, C, I, Invalid };
-
+using RegisterType = HLSLResourceBindingAttr::RegisterType;
static RegisterType getRegisterType(ResourceClass RC) {
switch (RC) {
case ResourceClass::SRV:
@@ -985,44 +983,43 @@ SemaHLSL::TakeLocForHLSLAttribute(const HLSLAttributedResourceType *RT) {
return LocInfo;
}
-// get the record decl from a var decl that we expect
-// represents a resource
-static CXXRecordDecl *getRecordDeclFromVarDecl(VarDecl *VD) {
- const Type *Ty = VD->getType()->getPointeeOrArrayElementType();
- assert(Ty && "Resource must have an element type.");
-
- if (Ty->isBuiltinType())
- return nullptr;
-
- CXXRecordDecl *TheRecordDecl = Ty->getAsCXXRecordDecl();
- assert(TheRecordDecl && "Resource should have a resource type declaration.");
- return TheRecordDecl;
-}
-
+// Returns handle type of a resource, if the VarDecl is a resource
+// or an array of resources
static const HLSLAttributedResourceType *
-findAttributedResourceTypeOnField(VarDecl *VD) {
+FindHandleTypeOnResource(const VarDecl *VD) {
+ // If VarDecl is a resource class, the first field must
+ // be the resource handle of type HLSLAttributedResourceType
assert(VD != nullptr && "expected VarDecl");
- if (RecordDecl *RD = getRecordDeclFromVarDecl(VD)) {
- for (auto *FD : RD->fields()) {
- if (const HLSLAttributedResourceType *AttrResType =
- dyn_cast<HLSLAttributedResourceType>(FD->getType().getTypePtr()))
- return AttrResType;
+ const Type *Ty = VD->getType()->getPointeeOrArrayElementType();
+ if (RecordDecl *RD = Ty->getAsCXXRecordDecl()) {
+ if (!RD->fields().empty()) {
+ const auto &FirstFD = RD->fields().begin();
+ return dyn_cast<HLSLAttributedResourceType>(
+ FirstFD->getType().getTypePtr());
}
}
return nullptr;
}
-// Iterate over RecordType fields and return true if any of them matched the
-// register type
-static bool ContainsResourceForRegisterType(Sema &S, const RecordType *RT,
- RegisterType RegType) {
+// Walks though the user defined record type, finds resource class
+// that matches the RegisterBinding.Type and assigns it to
+// RegisterBinding::Decl.
+static bool
+ProcessResourceBindingOnUserRecordDecl(const RecordType *RT,
+ HLSLResourceBindingAttr *RBA) {
+
llvm::SmallVector<const Type *> TypesToScan;
TypesToScan.emplace_back(RT);
+ RegisterType RegType = RBA->getRegisterType();
while (!TypesToScan.empty()) {
const Type *T = TypesToScan.pop_back_val();
- while (T->isArrayType())
+
+ while (T->isArrayType()) {
+ // FIXME: calculate the binding size from the array dimensions (or
+ // unbounded for unsized array) size *= (size_of_array);
T = T->getArrayElementTypeNoTypeQual();
+ }
if (T->isIntegralOrEnumerationType() || T->isFloatingType()) {
if (RegType == RegisterType::C)
return true;
@@ -1037,8 +1034,12 @@ static bool ContainsResourceForRegisterType(Sema &S, const RecordType *RT,
if (const HLSLAttributedResourceType *AttrResType =
dyn_cast<HLSLAttributedResourceType>(FieldTy)) {
ResourceClass RC = AttrResType->getAttrs().ResourceClass;
- if (getRegisterType(RC) == RegType)
+ if (getRegisterType(RC) == RegType) {
+ assert(RBA->getResourceField() == nullptr &&
+ "multiple register bindings of the same type are not allowed");
+ RBA->setResourceField(FD);
return true;
+ }
} else {
TypesToScan.emplace_back(FD->getType().getTypePtr());
}
@@ -1047,26 +1048,28 @@ static bool ContainsResourceForRegisterType(Sema &S, const RecordType *RT,
return false;
}
-static void CheckContainsResourceForRegisterType(Sema &S,
- SourceLocation &ArgLoc,
- Decl *D, RegisterType RegType,
- bool SpecifiedSpace) {
+// return false if the register binding is not valid
+static bool DiagnoseLocalRegisterBinding(Sema &S, SourceLocation &ArgLoc,
+ Decl *D, RegisterType RegType,
+ bool SpecifiedSpace) {
int RegTypeNum = static_cast<int>(RegType);
// check if the decl type is groupshared
if (D->hasAttr<HLSLGroupSharedAddressSpaceAttr>()) {
S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
- return;
+ return false;
}
// Cbuffers and Tbuffers are HLSLBufferDecl types
if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D)) {
ResourceClass RC = CBufferOrTBuffer->isCBuffer() ? ResourceClass::CBuffer
: ResourceClass::SRV;
- if (RegType != getRegisterType(RC))
- S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch)
- << RegTypeNum;
- return;
+ if (RegType == getRegisterType(RC))
+ return true;
+
+ S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch)
+ << RegTypeNum;
+ return false;
}
// Samplers, UAVs, and SRVs are VarDecl types
@@ -1075,11 +1078,13 @@ static void CheckContainsResourceForRegisterType(Sema &S,
// Resource
if (const HLSLAttributedResourceType *AttrResType =
- findAttributedResourceTypeOnField(VD)) {
- if (RegType != getRegisterType(AttrResType->getAttrs().ResourceClass))
- S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch)
- << RegTypeNum;
- return;
+ FindHandleTypeOnResource(VD)) {
+ if (RegType == getRegisterType(AttrResType->getAttrs().ResourceClass))
+ return true;
+
+ S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch)
+ << RegTypeNum;
+ return false;
}
const clang::Type *Ty = VD->getType().getTypePtr();
@@ -1088,36 +1093,43 @@ static void CheckContainsResourceForRegisterType(Sema &S,
// Basic types
if (Ty->isArithmeticType()) {
+ bool IsValid = true;
bool DeclaredInCOrTBuffer = isa<HLSLBufferDecl>(D->getDeclContext());
- if (SpecifiedSpace && !DeclaredInCOrTBuffer)
+ if (SpecifiedSpace && !DeclaredInCOrTBuffer) {
S.Diag(ArgLoc, diag::err_hlsl_space_on_global_constant);
+ IsValid = false;
+ }
if (!DeclaredInCOrTBuffer &&
(Ty->isIntegralType(S.getASTContext()) || Ty->isFloatingType())) {
// Default Globals
if (RegType == RegisterType::CBuffer)
S.Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_b);
- else if (RegType != RegisterType::C)
+ else if (RegType != RegisterType::C) {
S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
+ IsValid = false;
+ }
} else {
if (RegType == RegisterType::C)
S.Diag(ArgLoc, diag::warn_hlsl_register_type_c_packoffset);
- else
+ else {
S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
+ IsValid = false;
+ }
}
- } else if (Ty->isRecordType()) {
- // Class/struct types - walk the declaration and check each field and
- // subclass
- if (!ContainsResourceForRegisterType(S, Ty->getAs<RecordType>(), RegType))
- S.Diag(D->getLocation(), diag::warn_hlsl_user_defined_type_missing_member)
- << RegTypeNum;
- } else {
- // Anything else is an error
- S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
+ return IsValid;
}
+ if (Ty->isRecordType())
+ // RecordTypes will be diagnosed in ProcessResourceBindingOnDecl
+ // that is called from ActOnVariableDeclarator
+ return true;
+
+ // Anything else is an error
+ S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
+ return false;
}
-static void ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl,
+static bool ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl,
RegisterType regType) {
// make sure that there are no two register annotations
// applied to the decl with the same register type
@@ -1135,21 +1147,19 @@ static void ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl,
RegisterType otherRegType = getRegisterType(attr->getSlot());
if (RegisterTypesDetected[static_cast<int>(otherRegType)]) {
- if (PreviousConflicts[TheDecl].count(otherRegType))
- continue;
int otherRegTypeNum = static_cast<int>(otherRegType);
S.Diag(TheDecl->getLocation(),
diag::err_hlsl_duplicate_register_annotation)
<< otherRegTypeNum;
- PreviousConflicts[TheDecl].insert(otherRegType);
- } else {
- RegisterTypesDetected[static_cast<int>(otherRegType)] = true;
+ return false;
}
+ RegisterTypesDetected[static_cast<int>(otherRegType)] = true;
}
}
+ return true;
}
-static void DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc,
+static bool DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc,
Decl *D, RegisterType RegType,
bool SpecifiedSpace) {
@@ -1159,10 +1169,11 @@ static void DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc,
"expecting VarDecl or HLSLBufferDecl");
// check if the declaration contains resource matching the register type
- CheckContainsResourceForRegisterType(S, ArgLoc, D, RegType, SpecifiedSpace);
+ if (!DiagnoseLocalRegisterBinding(S, ArgLoc, D, RegType, SpecifiedSpace))
+ return false;
// next, if multiple register annotations exist, check that none conflict.
- ValidateMultipleRegisterAnnotations(S, D, RegType);
+ return ValidateMultipleRegisterAnnotations(S, D, RegType);
}
void SemaHLSL::handleResourceBindingAttr(Decl *TheDecl, const ParsedAttr &AL) {
@@ -1203,23 +1214,24 @@ void SemaHLSL::handleResourceBindingAttr(Decl *TheDecl, const ParsedAttr &AL) {
Slot = Str;
}
- RegisterType regType;
+ RegisterType RegType;
+ unsigned SlotNum = 0;
+ unsigned SpaceNum = 0;
// Validate.
if (!Slot.empty()) {
- regType = getRegisterType(Slot);
- if (regType == RegisterType::I) {
+ RegType = getRegisterType(Slot);
+ if (RegType == RegisterType::I) {
Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_i);
return;
}
- if (regType == RegisterType::Invalid) {
+ if (RegType == RegisterType::Invalid) {
Diag(ArgLoc, diag::err_hlsl_binding_type_invalid) << Slot.substr(0, 1);
return;
}
- StringRef SlotNum = Slot.substr(1);
- unsigned Num = 0;
- if (SlotNum.getAsInteger(10, Num)) {
+ StringRef SlotNumStr = Slot.substr(1);
+ if (SlotNumStr.getAsInteger(10, SlotNum)) {
Diag(ArgLoc, diag::err_hlsl_unsupported_register_number);
return;
}
@@ -1229,20 +1241,22 @@ void SemaHLSL::handleResourceBindingAttr(Decl *TheDecl, const ParsedAttr &AL) {
Diag(SpaceArgLoc, diag::err_hlsl_expected_space) << Space;
return;
}
- StringRef SpaceNum = Space.substr(5);
- unsigned Num = 0;
- if (SpaceNum.getAsInteger(10, Num)) {
+ StringRef SpaceNumStr = Space.substr(5);
+ if (SpaceNumStr.getAsInteger(10, SpaceNum)) {
Diag(SpaceArgLoc, diag::err_hlsl_expected_space) << Space;
return;
}
- DiagnoseHLSLRegisterAttribute(SemaRef, ArgLoc, TheDecl, regType,
- SpecifiedSpace);
+ if (!DiagnoseHLSLRegisterAttribute(SemaRef, ArgLoc, TheDecl, RegType,
+ SpecifiedSpace))
+ return;
HLSLResourceBindingAttr *NewAttr =
HLSLResourceBindingAttr::Create(getASTContext(), Slot, Space, AL);
- if (NewAttr)
+ if (NewAttr) {
+ NewAttr->setBinding(RegType, SlotNum, SpaceNum);
TheDecl->addAttr(NewAttr);
+ }
}
void SemaHLSL::handleParamModifierAttr(Decl *D, const ParsedAttr &AL) {
@@ -2228,3 +2242,62 @@ QualType SemaHLSL::getInoutParameterType(QualType Ty) {
Ty.addRestrict();
return Ty;
}
+
+// Walks though existing explicit bindings, finds the actual resource class
+// decl the binding applies to and sets it to attr->ResourceField.
+// Additional processing of resource binding can be added here later on,
+// such as preparation for overapping resource detection or implicit binding.
+void SemaHLSL::ProcessResourceBindingOnDecl(VarDecl *D) {
+ if (!D->hasGlobalStorage())
+ return;
+
+ for (Attr *A : D->attrs()) {
+ HLSLResourceBindingAttr *RBA = dyn_cast<HLSLResourceBindingAttr>(A);
+ if (!RBA)
+ continue;
+
+ // // Cbuffers and Tbuffers are HLSLBufferDecl types
+ if (const HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D)) {
+ assert(RBA->getRegisterType() ==
+ getRegisterType(CBufferOrTBuffer->isCBuffer()
+ ? ResourceClass::CBuffer
+ : ResourceClass::SRV) &&
+ "this should have been handled in DiagnoseLocalRegisterBinding");
+ // should we handle HLSLBufferDecl here?
+ continue;
+ }
+
+ // Samplers, UAVs, and SRVs are VarDecl types
+ assert(isa<VarDecl>(D) && "D is expected to be VarDecl or HLSLBufferDecl");
+ const VarDecl *VD = cast<VarDecl>(D);
+
+ // Register binding directly on global resource class variable
+ if (const HLSLAttributedResourceType *AttrResType =
+ FindHandleTypeOnResource(VD)) {
+ // FIXME: if array, calculate the binding size from the array dimensions
+ // (or unbounded for unsized array)
+ assert(RBA->getResourceField() == nullptr);
+ continue;
+ }
+
+ // Global array
+ const clang::Type *Ty = VD->getType().getTypePtr();
+ while (Ty->isArrayType()) {
+ Ty = Ty->getArrayElementTypeNoTypeQual();
+ }
+
+ // Basic types
+ if (Ty->isArithmeticType()) {
+ continue;
+ }
+
+ if (Ty->isRecordType()) {
+ if (!ProcessResourceBindingOnUserRecordDecl(Ty->getAs<RecordType>(),
+ RBA)) {
+ SemaRef.Diag(D->getLocation(),
+ diag::warn_hlsl_user_defined_type_missing_member)
+ << static_cast<int>(RBA->getRegisterType());
+ }
+ }
+ }
+}
diff --git a/clang/test/SemaHLSL/resource_binding_attr_error_udt.hlsl b/clang/test/SemaHLSL/resource_binding_attr_error_udt.hlsl
index ea2d576e4cca55..40517f393e1284 100644
--- a/clang/test/SemaHLSL/resource_binding_attr_error_udt.hlsl
+++ b/clang/test/SemaHLSL/resource_binding_attr_error_udt.hlsl
@@ -106,7 +106,6 @@ struct Eg12{
MySRV s1;
MySRV s2;
};
-// expected-warning at +3{{binding type 'u' only applies to types containing UAV resources}}
// expected-warning at +2{{binding type 'u' only applies to types containing UAV resources}}
// expected-error at +1{{binding type 'u' cannot be applied more than once}}
Eg12 e12 : register(u9) : register(u10);
@@ -115,12 +114,14 @@ struct Eg13{
MySRV s1;
MySRV s2;
};
-// expected-warning at +4{{binding type 'u' only applies to types containing UAV resources}}
// expected-warning at +3{{binding type 'u' only applies to types containing UAV resources}}
-// expected-warning at +2{{binding type 'u' only applies to types containing UAV resources}}
+// expected-error at +2{{binding type 'u' cannot be applied more than once}}
// expected-error at +1{{binding type 'u' cannot be applied more than once}}
Eg13 e13 : register(u9) : register(u10) : register(u11);
+// expected-error at +1{{binding type 't' cannot be applied more than once}}
+Eg13 e13_2 : register(t11) : register(t12);
+
struct Eg14{
MyTemplatedUAV<int> r1;
};
@@ -132,4 +133,3 @@ struct Eg15 {
};
// expected no error
Eg15 e15 : register(c0);
-
>From a6c06943ce5df79e6765e12874c96c907b20d030 Mon Sep 17 00:00:00 2001
From: Helena Kotas <hekotas at microsoft.com>
Date: Fri, 4 Oct 2024 13:52:47 -0700
Subject: [PATCH 2/5] clang-format
---
clang/lib/Sema/SemaHLSL.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 568a8de30c1fc5..5c27a74a853bba 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -2250,7 +2250,7 @@ QualType SemaHLSL::getInoutParameterType(QualType Ty) {
void SemaHLSL::ProcessResourceBindingOnDecl(VarDecl *D) {
if (!D->hasGlobalStorage())
return;
-
+
for (Attr *A : D->attrs()) {
HLSLResourceBindingAttr *RBA = dyn_cast<HLSLResourceBindingAttr>(A);
if (!RBA)
>From a6a52327bef4325a00a2b8a1715b8b5b1315994f Mon Sep 17 00:00:00 2001
From: Helena Kotas <hekotas at microsoft.com>
Date: Wed, 9 Oct 2024 16:34:06 -0700
Subject: [PATCH 3/5] Collect all resource binding requirements and analyze
explicit bindings based on that. Also adds binding size calculation and
removes the ResourceField member from HLSLResourceBindingAttr.
---
clang/include/clang/Basic/Attr.td | 25 ++-
clang/include/clang/Sema/SemaHLSL.h | 59 +++++-
clang/lib/Sema/SemaDecl.cpp | 2 +-
clang/lib/Sema/SemaHLSL.cpp | 276 ++++++++++++++++++----------
4 files changed, 256 insertions(+), 106 deletions(-)
diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index 668c599da81390..3997ffe78fbf96 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -4591,22 +4591,20 @@ def HLSLResourceBinding: InheritableAttr {
let AdditionalMembers = [{
enum class RegisterType : unsigned { SRV, UAV, CBuffer, Sampler, C, I, Invalid };
- const FieldDecl *ResourceField = nullptr;
RegisterType RegType;
unsigned SlotNumber;
unsigned SpaceNumber;
+
+ // Size of the binding
+ // 0 == not set
+ //-1 == unbounded
+ int Size;
- void setBinding(RegisterType RT, unsigned SlotNum, unsigned SpaceNum) {
+ void setBinding(RegisterType RT, unsigned SlotNum, unsigned SpaceNum, int Size = 0) {
RegType = RT;
SlotNumber = SlotNum;
SpaceNumber = SpaceNum;
}
- void setResourceField(const FieldDecl *FD) {
- ResourceField = FD;
- }
- const FieldDecl *getResourceField() {
- return ResourceField;
- }
RegisterType getRegisterType() {
return RegType;
}
@@ -4616,6 +4614,17 @@ def HLSLResourceBinding: InheritableAttr {
unsigned getSpaceNumber() {
return SpaceNumber;
}
+ unsigned getSize() {
+ assert(Size == -1 || Size > 0 && "size not set");
+ return Size;
+ }
+ void setSize(int N) {
+ assert(N == -1 || N > 0 && "unexpected size value");
+ Size = N;
+ }
+ bool isSizeUnbounded() {
+ return Size == -1;
+ }
}];
}
diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h
index 018e7ea5901a2b..ce262fd41dff37 100644
--- a/clang/include/clang/Sema/SemaHLSL.h
+++ b/clang/include/clang/Sema/SemaHLSL.h
@@ -30,12 +30,60 @@ class ParsedAttr;
class Scope;
class VarDecl;
+using llvm::dxil::ResourceClass;
+
// FIXME: This can be hidden (as static function in SemaHLSL.cpp) once we no
// longer need to create builtin buffer types in HLSLExternalSemaSource.
bool CreateHLSLAttributedResourceType(
Sema &S, QualType Wrapped, ArrayRef<const Attr *> AttrList,
QualType &ResType, HLSLAttributedResourceLocInfo *LocInfo = nullptr);
+enum class BindingType : uint8_t { NotAssigned, Explicit, Implicit };
+
+// DeclBindingInfo struct stores information about required/assigned resource
+// binding onon a declaration for specific resource class.
+struct DeclBindingInfo {
+ const VarDecl *Decl;
+ ResourceClass ResClass;
+ int Size; // -1 == unbounded array
+ const HLSLResourceBindingAttr *Attr;
+ BindingType BindType;
+
+ DeclBindingInfo(const VarDecl *Decl, ResourceClass ResClass, int Size = 0,
+ BindingType BindType = BindingType::NotAssigned,
+ const HLSLResourceBindingAttr *Attr = nullptr)
+ : Decl(Decl), ResClass(ResClass), Size(Size), Attr(Attr),
+ BindType(BindType) {}
+
+ void setBindingAttribute(HLSLResourceBindingAttr *A, BindingType BT) {
+ assert(Attr == nullptr && BindType == BindingType::NotAssigned &&
+ "binding attribute already assigned");
+ Attr = A;
+ BindType = BT;
+ }
+};
+
+// ResourceBindings class stores information about all resource bindings
+// in a shader. It is used for binding diagnostics and implicit binding
+// assigments.
+class ResourceBindings {
+public:
+ DeclBindingInfo *addDeclBindingInfo(const VarDecl *VD, ResourceClass ResClass,
+ int Size);
+ DeclBindingInfo *getDeclBindingInfo(const VarDecl *VD,
+ ResourceClass ResClass);
+ bool hasBindingInfoForDecl(const VarDecl *VD);
+
+private:
+ // List of all resource bindings required by the shader.
+ // A global declaration can have multiple bindings for different
+ // resource classes. They are all stored sequentially in this list.
+ // The DeclToBindingListIndex hashtable maps a declaration to the
+ // index of the first binding info in the list.
+ llvm::SmallVector<DeclBindingInfo> BindingsList;
+ llvm::DenseMap<const VarDecl *, unsigned> DeclToBindingListIndex;
+};
+
class SemaHLSL : public SemaBase {
public:
SemaHLSL(Sema &S);
@@ -56,6 +104,7 @@ class SemaHLSL : public SemaBase {
mergeParamModifierAttr(Decl *D, const AttributeCommonInfo &AL,
HLSLParamModifierAttr::Spelling Spelling);
void ActOnTopLevelFunction(FunctionDecl *FD);
+ void ActOnVariableDeclarator(VarDecl *VD);
void CheckEntryPoint(FunctionDecl *FD);
void CheckSemanticAnnotation(FunctionDecl *EntryPoint, const Decl *Param,
const HLSLAnnotationAttr *AnnotationAttr);
@@ -63,7 +112,6 @@ class SemaHLSL : public SemaBase {
const Attr *A, llvm::Triple::EnvironmentType Stage,
std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages);
void DiagnoseAvailabilityViolations(TranslationUnitDecl *TU);
- void ProcessResourceBindingOnDecl(VarDecl *D);
QualType handleVectorBinOpConversion(ExprResult &LHS, ExprResult &RHS,
QualType LHSType, QualType RHSType,
@@ -104,6 +152,15 @@ class SemaHLSL : public SemaBase {
llvm::DenseMap<const HLSLAttributedResourceType *,
HLSLAttributedResourceLocInfo>
LocsForHLSLAttributedResources;
+
+ // List of all resource bindings
+ ResourceBindings Bindings;
+
+private:
+ void FindResourcesOnVarDecl(VarDecl *D);
+ void FindResourcesOnUserRecordDecl(const VarDecl *VD, const RecordType *RT,
+ int Size);
+ void ProcessExplicitBindingsOnDecl(VarDecl *D);
};
} // namespace clang
diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp
index 8e27a5e068e702..770d00710a6816 100644
--- a/clang/lib/Sema/SemaDecl.cpp
+++ b/clang/lib/Sema/SemaDecl.cpp
@@ -7877,7 +7877,7 @@ NamedDecl *Sema::ActOnVariableDeclarator(
ProcessDeclAttributes(S, NewVD, D);
if (getLangOpts().HLSL)
- HLSL().ProcessResourceBindingOnDecl(NewVD);
+ HLSL().ActOnVariableDeclarator(NewVD);
// FIXME: This is probably the wrong location to be doing this and we should
// probably be doing this for more attributes (especially for function
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 5c27a74a853bba..197ee63c07deeb 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -40,8 +40,8 @@
#include <utility>
using namespace clang;
-using llvm::dxil::ResourceClass;
using RegisterType = HLSLResourceBindingAttr::RegisterType;
+
static RegisterType getRegisterType(ResourceClass RC) {
switch (RC) {
case ResourceClass::SRV:
@@ -81,6 +81,49 @@ static RegisterType getRegisterType(StringRef Slot) {
}
}
+static ResourceClass getResourceClass(RegisterType RT) {
+ switch (RT) {
+ case RegisterType::SRV:
+ return ResourceClass::SRV;
+ case RegisterType::UAV:
+ return ResourceClass::UAV;
+ case RegisterType::CBuffer:
+ return ResourceClass::CBuffer;
+ case RegisterType::Sampler:
+ return ResourceClass::Sampler;
+ default:
+ llvm_unreachable("unexpected RegisterType value");
+ }
+}
+
+DeclBindingInfo *ResourceBindings::addDeclBindingInfo(const VarDecl *VD,
+ ResourceClass ResClass,
+ int Size) {
+ assert(getDeclBindingInfo(VD, ResClass) == nullptr &&
+ "DeclBindingInfo already added");
+ if (DeclToBindingListIndex.find(VD) == DeclToBindingListIndex.end())
+ DeclToBindingListIndex[VD] = BindingsList.size();
+ return &BindingsList.emplace_back(DeclBindingInfo(VD, ResClass, Size));
+}
+
+DeclBindingInfo *ResourceBindings::getDeclBindingInfo(const VarDecl *VD,
+ ResourceClass ResClass) {
+ auto Entry = DeclToBindingListIndex.find(VD);
+ if (Entry != DeclToBindingListIndex.end()) {
+ unsigned Index = Entry->getSecond();
+ while (Index < BindingsList.size() && BindingsList[Index].Decl == VD) {
+ if (BindingsList[Index].ResClass == ResClass)
+ return &BindingsList[Index];
+ Index++;
+ }
+ }
+ return nullptr;
+}
+
+bool ResourceBindings::hasBindingInfoForDecl(const VarDecl *VD) {
+ return DeclToBindingListIndex.contains(VD);
+}
+
SemaHLSL::SemaHLSL(Sema &S) : SemaBase(S) {}
Decl *SemaHLSL::ActOnStartBuffer(Scope *BufferScope, bool CBuffer,
@@ -983,14 +1026,11 @@ SemaHLSL::TakeLocForHLSLAttribute(const HLSLAttributedResourceType *RT) {
return LocInfo;
}
-// Returns handle type of a resource, if the VarDecl is a resource
-// or an array of resources
+// Returns handle type of a resource, if the type is a resource
static const HLSLAttributedResourceType *
-FindHandleTypeOnResource(const VarDecl *VD) {
- // If VarDecl is a resource class, the first field must
+FindHandleTypeOnResource(const Type *Ty) {
+ // If Ty is a resource class, the first field must
// be the resource handle of type HLSLAttributedResourceType
- assert(VD != nullptr && "expected VarDecl");
- const Type *Ty = VD->getType()->getPointeeOrArrayElementType();
if (RecordDecl *RD = Ty->getAsCXXRecordDecl()) {
if (!RD->fields().empty()) {
const auto &FirstFD = RD->fields().begin();
@@ -1001,51 +1041,53 @@ FindHandleTypeOnResource(const VarDecl *VD) {
return nullptr;
}
-// Walks though the user defined record type, finds resource class
-// that matches the RegisterBinding.Type and assigns it to
-// RegisterBinding::Decl.
-static bool
-ProcessResourceBindingOnUserRecordDecl(const RecordType *RT,
- HLSLResourceBindingAttr *RBA) {
-
- llvm::SmallVector<const Type *> TypesToScan;
- TypesToScan.emplace_back(RT);
- RegisterType RegType = RBA->getRegisterType();
-
- while (!TypesToScan.empty()) {
- const Type *T = TypesToScan.pop_back_val();
-
- while (T->isArrayType()) {
- // FIXME: calculate the binding size from the array dimensions (or
- // unbounded for unsized array) size *= (size_of_array);
- T = T->getArrayElementTypeNoTypeQual();
- }
- if (T->isIntegralOrEnumerationType() || T->isFloatingType()) {
- if (RegType == RegisterType::C)
- return true;
+// Returns handle type of a resource, if the VarDecl is a resource
+static const HLSLAttributedResourceType *
+FindHandleTypeOnResource(const VarDecl *VD) {
+ assert(VD != nullptr && "expected VarDecl");
+ return FindHandleTypeOnResource(VD->getType().getTypePtr());
+}
+
+// Walks though the global variable declaration, collects all resource binding
+// requirements and adds them to Bindings
+void SemaHLSL::FindResourcesOnUserRecordDecl(const VarDecl *VD,
+ const RecordType *RT, int Size) {
+ const RecordDecl *RD = RT->getDecl();
+ for (FieldDecl *FD : RD->fields()) {
+ const Type *Ty = FD->getType()->getUnqualifiedDesugaredType();
+
+ // Calculate array size and unwrap
+ int ArraySize = 1;
+ assert(!Ty->isIncompleteArrayType() &&
+ "incomplete arrays inside user defined types are not supported");
+ while (Ty->isConstantArrayType()) {
+ const ConstantArrayType *CAT = cast<ConstantArrayType>(Ty);
+ ArraySize *= CAT->getSize().getSExtValue();
+ Ty = CAT->getElementType()->getUnqualifiedDesugaredType();
}
- const RecordType *RT = T->getAs<RecordType>();
- if (!RT)
+
+ if (!Ty->isRecordType())
continue;
- const RecordDecl *RD = RT->getDecl();
- for (FieldDecl *FD : RD->fields()) {
- const Type *FieldTy = FD->getType().getTypePtr();
- if (const HLSLAttributedResourceType *AttrResType =
- dyn_cast<HLSLAttributedResourceType>(FieldTy)) {
- ResourceClass RC = AttrResType->getAttrs().ResourceClass;
- if (getRegisterType(RC) == RegType) {
- assert(RBA->getResourceField() == nullptr &&
- "multiple register bindings of the same type are not allowed");
- RBA->setResourceField(FD);
- return true;
- }
- } else {
- TypesToScan.emplace_back(FD->getType().getTypePtr());
- }
+ // Field is a resource or array of resources
+ if (const HLSLAttributedResourceType *AttrResType =
+ FindHandleTypeOnResource(Ty)) {
+ ResourceClass RC = AttrResType->getAttrs().ResourceClass;
+
+ // Add a new DeclBindingInfo to Bindings. Update the binding size if
+ // a binding info already exists (there are multiple resources of same
+ // resource class in this user decl)
+ if (auto *DBI = Bindings.getDeclBindingInfo(VD, RC))
+ DBI->Size += Size * ArraySize;
+ else
+ Bindings.addDeclBindingInfo(VD, RC, Size);
+ } else if (const RecordType *RT = dyn_cast<RecordType>(Ty)) {
+ // Recursively scan embedded struct or class; it would be nice to do this
+ // without recursion, but tricky to corrently calculate the size.
+ // Hopefully nesting of structs in structs too many levels is unlikely.
+ FindResourcesOnUserRecordDecl(VD, RT, Size);
}
}
- return false;
}
// return false if the register binding is not valid
@@ -1093,11 +1135,9 @@ static bool DiagnoseLocalRegisterBinding(Sema &S, SourceLocation &ArgLoc,
// Basic types
if (Ty->isArithmeticType()) {
- bool IsValid = true;
bool DeclaredInCOrTBuffer = isa<HLSLBufferDecl>(D->getDeclContext());
if (SpecifiedSpace && !DeclaredInCOrTBuffer) {
S.Diag(ArgLoc, diag::err_hlsl_space_on_global_constant);
- IsValid = false;
}
if (!DeclaredInCOrTBuffer &&
@@ -1107,17 +1147,15 @@ static bool DiagnoseLocalRegisterBinding(Sema &S, SourceLocation &ArgLoc,
S.Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_b);
else if (RegType != RegisterType::C) {
S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
- IsValid = false;
}
} else {
if (RegType == RegisterType::C)
S.Diag(ArgLoc, diag::warn_hlsl_register_type_c_packoffset);
else {
S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
- IsValid = false;
}
}
- return IsValid;
+ return false;
}
if (Ty->isRecordType())
// RecordTypes will be diagnosed in ProcessResourceBindingOnDecl
@@ -2057,6 +2095,7 @@ bool SemaHLSL::IsIntangibleType(clang::QualType QT) {
CXXRecordDecl *RD = RT->getAsCXXRecordDecl();
assert(RD != nullptr &&
"all HLSL struct and classes should be CXXRecordDecl");
+ assert(RD->isCompleteDefinition() && "expecting complete type");
return RD->isHLSLIntangible();
}
@@ -2243,61 +2282,106 @@ QualType SemaHLSL::getInoutParameterType(QualType Ty) {
return Ty;
}
-// Walks though existing explicit bindings, finds the actual resource class
-// decl the binding applies to and sets it to attr->ResourceField.
-// Additional processing of resource binding can be added here later on,
-// such as preparation for overapping resource detection or implicit binding.
-void SemaHLSL::ProcessResourceBindingOnDecl(VarDecl *D) {
- if (!D->hasGlobalStorage())
+void SemaHLSL::ActOnVariableDeclarator(VarDecl *VD) {
+ if (VD->hasGlobalStorage()) {
+ // make sure the declaration has a complete type
+ if (SemaRef.RequireCompleteType(
+ VD->getLocation(),
+ SemaRef.getASTContext().getBaseElementType(VD->getType()),
+ diag::err_typecheck_decl_incomplete_type)) {
+ VD->setInvalidDecl();
+ return;
+ }
+
+ // find all resources on decl
+ if (IsIntangibleType(VD->getType()))
+ FindResourcesOnVarDecl(VD);
+
+ // process explicit bindings
+ ProcessExplicitBindingsOnDecl(VD);
+ }
+}
+
+// Walks though the global variable declaration, collects all resource binding
+// requirements and adds them to Bindings
+void SemaHLSL::FindResourcesOnVarDecl(VarDecl *VD) {
+ assert(VD->hasGlobalStorage() && IsIntangibleType(VD->getType()) &&
+ "expected global variable that contains HLSL resource");
+
+ // Cbuffers and Tbuffers are HLSLBufferDecl types
+ if (const HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(VD)) {
+ Bindings.addDeclBindingInfo(VD,
+ CBufferOrTBuffer->isCBuffer()
+ ? ResourceClass::CBuffer
+ : ResourceClass::SRV,
+ 1);
return;
+ }
- for (Attr *A : D->attrs()) {
- HLSLResourceBindingAttr *RBA = dyn_cast<HLSLResourceBindingAttr>(A);
- if (!RBA)
- continue;
+ // Calculate size of array and unwrap
+ int Size = 1;
+ const Type *Ty = VD->getType()->getUnqualifiedDesugaredType();
+ if (Ty->isIncompleteArrayType())
+ Size = -1;
+ while (Ty->isConstantArrayType()) {
+ const ConstantArrayType *CAT = cast<ConstantArrayType>(Ty);
+ Size *= CAT->getSize().getSExtValue();
+ Ty = CAT->getElementType()->getUnqualifiedDesugaredType();
+ }
- // // Cbuffers and Tbuffers are HLSLBufferDecl types
- if (const HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D)) {
- assert(RBA->getRegisterType() ==
- getRegisterType(CBufferOrTBuffer->isCBuffer()
- ? ResourceClass::CBuffer
- : ResourceClass::SRV) &&
- "this should have been handled in DiagnoseLocalRegisterBinding");
- // should we handle HLSLBufferDecl here?
- continue;
- }
+ // Resource (or array of resources)
+ if (const HLSLAttributedResourceType *AttrResType =
+ FindHandleTypeOnResource(Ty)) {
+ Bindings.addDeclBindingInfo(VD, AttrResType->getAttrs().ResourceClass,
+ Size);
+ return;
+ }
- // Samplers, UAVs, and SRVs are VarDecl types
- assert(isa<VarDecl>(D) && "D is expected to be VarDecl or HLSLBufferDecl");
- const VarDecl *VD = cast<VarDecl>(D);
+ assert(Size != -1 &&
+ "unbounded arrays of user defined types are not supported");
- // Register binding directly on global resource class variable
- if (const HLSLAttributedResourceType *AttrResType =
- FindHandleTypeOnResource(VD)) {
- // FIXME: if array, calculate the binding size from the array dimensions
- // (or unbounded for unsized array)
- assert(RBA->getResourceField() == nullptr);
+ // User defined record type
+ if (const RecordType *RT = dyn_cast<RecordType>(Ty))
+ FindResourcesOnUserRecordDecl(VD, RT, Size);
+}
+
+// Walks though the explicit resource binding attributes on the declaration,
+// and makes sure there is a resource that matched the binding and updates
+// DeclBindingInfoLists
+void SemaHLSL::ProcessExplicitBindingsOnDecl(VarDecl *VD) {
+ assert(VD->hasGlobalStorage() && "expected global variable");
+
+ for (Attr *A : VD->attrs()) {
+ HLSLResourceBindingAttr *RBA = dyn_cast<HLSLResourceBindingAttr>(A);
+ if (!RBA)
continue;
- }
- // Global array
- const clang::Type *Ty = VD->getType().getTypePtr();
- while (Ty->isArrayType()) {
- Ty = Ty->getArrayElementTypeNoTypeQual();
- }
+ RegisterType RT = RBA->getRegisterType();
+ assert(RT != RegisterType::I && RT != RegisterType::Invalid &&
+ "invalid or obsolete register type should never have an attribute "
+ "created");
- // Basic types
- if (Ty->isArithmeticType()) {
+ // These were already diagnosed earlier
+ if (RT == RegisterType::C) {
+ if (Bindings.hasBindingInfoForDecl(VD))
+ SemaRef.Diag(VD->getLocation(),
+ diag::warn_hlsl_user_defined_type_missing_member)
+ << static_cast<int>(RT);
continue;
}
- if (Ty->isRecordType()) {
- if (!ProcessResourceBindingOnUserRecordDecl(Ty->getAs<RecordType>(),
- RBA)) {
- SemaRef.Diag(D->getLocation(),
- diag::warn_hlsl_user_defined_type_missing_member)
- << static_cast<int>(RBA->getRegisterType());
- }
+ // Find DeclBindingInfo for this binding and update it, or report error
+ // if it does not exist (the user type does not contain resources with the
+ // expected resource class).
+ ResourceClass RC = getResourceClass(RT);
+ if (DeclBindingInfo *BI = Bindings.getDeclBindingInfo(VD, RC)) {
+ // update binding info
+ RBA->setSize(BI->Size);
+ BI->setBindingAttribute(RBA, BindingType::Explicit);
+ } else {
+ SemaRef.Diag(VD->getLocation(),
+ diag::warn_hlsl_user_defined_type_missing_member)
+ << static_cast<int>(RT);
}
}
}
>From aa6247f414b2bd3d39f349646f3a97ec72d5d517 Mon Sep 17 00:00:00 2001
From: Helena Kotas <hekotas at microsoft.com>
Date: Wed, 9 Oct 2024 17:08:25 -0700
Subject: [PATCH 4/5] removed unused variable, cleanup
---
clang/lib/Sema/SemaHLSL.cpp | 14 +++-----------
1 file changed, 3 insertions(+), 11 deletions(-)
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 197ee63c07deeb..0423340ee5fc4f 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1136,24 +1136,21 @@ static bool DiagnoseLocalRegisterBinding(Sema &S, SourceLocation &ArgLoc,
// Basic types
if (Ty->isArithmeticType()) {
bool DeclaredInCOrTBuffer = isa<HLSLBufferDecl>(D->getDeclContext());
- if (SpecifiedSpace && !DeclaredInCOrTBuffer) {
+ if (SpecifiedSpace && !DeclaredInCOrTBuffer)
S.Diag(ArgLoc, diag::err_hlsl_space_on_global_constant);
- }
if (!DeclaredInCOrTBuffer &&
(Ty->isIntegralType(S.getASTContext()) || Ty->isFloatingType())) {
// Default Globals
if (RegType == RegisterType::CBuffer)
S.Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_b);
- else if (RegType != RegisterType::C) {
+ else if (RegType != RegisterType::C)
S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
- }
} else {
if (RegType == RegisterType::C)
S.Diag(ArgLoc, diag::warn_hlsl_register_type_c_packoffset);
- else {
+ else
S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
- }
}
return false;
}
@@ -1172,13 +1169,8 @@ static bool ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl,
// make sure that there are no two register annotations
// applied to the decl with the same register type
bool RegisterTypesDetected[5] = {false};
-
RegisterTypesDetected[static_cast<int>(regType)] = true;
- // we need a static map to keep track of previous conflicts
- // so that we don't emit the same error multiple times
- static std::map<Decl *, std::set<RegisterType>> PreviousConflicts;
-
for (auto it = TheDecl->attr_begin(); it != TheDecl->attr_end(); ++it) {
if (HLSLResourceBindingAttr *attr =
dyn_cast<HLSLResourceBindingAttr>(*it)) {
>From a6edabe43eefc2957932498ee35b71e800af9fdd Mon Sep 17 00:00:00 2001
From: Helena Kotas <hekotas at microsoft.com>
Date: Tue, 15 Oct 2024 21:14:19 -0700
Subject: [PATCH 5/5] Code review feedback
- remove size calculation and storage - it is currently not used or tested
- remove invalid register type
- set fields on HLSLResourceBindingAttr as private and accessors public, add const
- update function names
- update comments
- use more efficient SmallVector and DenseMap methods
---
clang/include/clang/Basic/Attr.td | 31 +++----
clang/include/clang/Sema/SemaHLSL.h | 16 ++--
clang/lib/Sema/SemaHLSL.cpp | 121 ++++++++++++++--------------
3 files changed, 75 insertions(+), 93 deletions(-)
diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index 916757ccbe2d47..0259b6e40ca962 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -4594,42 +4594,29 @@ def HLSLResourceBinding: InheritableAttr {
let Args = [StringArgument<"Slot">, StringArgument<"Space", 1>];
let Documentation = [HLSLResourceBindingDocs];
let AdditionalMembers = [{
- enum class RegisterType : unsigned { SRV, UAV, CBuffer, Sampler, C, I, Invalid };
-
+ public:
+ enum class RegisterType : unsigned { SRV, UAV, CBuffer, Sampler, C, I };
+
+ private:
RegisterType RegType;
unsigned SlotNumber;
unsigned SpaceNumber;
-
- // Size of the binding
- // 0 == not set
- //-1 == unbounded
- int Size;
- void setBinding(RegisterType RT, unsigned SlotNum, unsigned SpaceNum, int Size = 0) {
+ public:
+ void setBinding(RegisterType RT, unsigned SlotNum, unsigned SpaceNum) {
RegType = RT;
SlotNumber = SlotNum;
SpaceNumber = SpaceNum;
}
- RegisterType getRegisterType() {
+ RegisterType getRegisterType() const {
return RegType;
}
- unsigned getSlotNumber() {
+ unsigned getSlotNumber() const {
return SlotNumber;
}
- unsigned getSpaceNumber() {
+ unsigned getSpaceNumber() const {
return SpaceNumber;
}
- unsigned getSize() {
- assert(Size == -1 || Size > 0 && "size not set");
- return Size;
- }
- void setSize(int N) {
- assert(N == -1 || N > 0 && "unexpected size value");
- Size = N;
- }
- bool isSizeUnbounded() {
- return Size == -1;
- }
}];
}
diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h
index ce262fd41dff37..5eda4d544a5ae5 100644
--- a/clang/include/clang/Sema/SemaHLSL.h
+++ b/clang/include/clang/Sema/SemaHLSL.h
@@ -45,15 +45,13 @@ enum class BindingType : uint8_t { NotAssigned, Explicit, Implicit };
struct DeclBindingInfo {
const VarDecl *Decl;
ResourceClass ResClass;
- int Size; // -1 == unbounded array
const HLSLResourceBindingAttr *Attr;
BindingType BindType;
DeclBindingInfo(const VarDecl *Decl, ResourceClass ResClass, int Size = 0,
BindingType BindType = BindingType::NotAssigned,
const HLSLResourceBindingAttr *Attr = nullptr)
- : Decl(Decl), ResClass(ResClass), Size(Size), Attr(Attr),
- BindType(BindType) {}
+ : Decl(Decl), ResClass(ResClass), Attr(Attr), BindType(BindType) {}
void setBindingAttribute(HLSLResourceBindingAttr *A, BindingType BT) {
assert(Attr == nullptr && BindType == BindingType::NotAssigned &&
@@ -68,8 +66,8 @@ struct DeclBindingInfo {
// assigments.
class ResourceBindings {
public:
- DeclBindingInfo *addDeclBindingInfo(const VarDecl *VD, ResourceClass ResClass,
- int Size);
+ DeclBindingInfo *addDeclBindingInfo(const VarDecl *VD,
+ ResourceClass ResClass);
DeclBindingInfo *getDeclBindingInfo(const VarDecl *VD,
ResourceClass ResClass);
bool hasBindingInfoForDecl(const VarDecl *VD);
@@ -157,10 +155,10 @@ class SemaHLSL : public SemaBase {
ResourceBindings Bindings;
private:
- void FindResourcesOnVarDecl(VarDecl *D);
- void FindResourcesOnUserRecordDecl(const VarDecl *VD, const RecordType *RT,
- int Size);
- void ProcessExplicitBindingsOnDecl(VarDecl *D);
+ void collectResourcesOnVarDecl(VarDecl *D);
+ void collectResourcesOnUserRecordDecl(const VarDecl *VD,
+ const RecordType *RT);
+ void processExplicitBindingsOnDecl(VarDecl *D);
};
} // namespace clang
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 668d3ad9ecd6ba..a58c4281eeb375 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -56,28 +56,37 @@ static RegisterType getRegisterType(ResourceClass RC) {
llvm_unreachable("unexpected ResourceClass value");
}
-static RegisterType getRegisterType(StringRef Slot) {
+// Converts the first letter of string Slot to RegisterType.
+// Returns false if the letter does not correspond to a valid register type.
+static bool convertToRegisterType(StringRef Slot, RegisterType *RT) {
+ assert(RT != nullptr);
switch (Slot[0]) {
case 't':
case 'T':
- return RegisterType::SRV;
+ *RT = RegisterType::SRV;
+ return true;
case 'u':
case 'U':
- return RegisterType::UAV;
+ *RT = RegisterType::UAV;
+ return true;
case 'b':
case 'B':
- return RegisterType::CBuffer;
+ *RT = RegisterType::CBuffer;
+ return true;
case 's':
case 'S':
- return RegisterType::Sampler;
+ *RT = RegisterType::Sampler;
+ return true;
case 'c':
case 'C':
- return RegisterType::C;
+ *RT = RegisterType::C;
+ return true;
case 'i':
case 'I':
- return RegisterType::I;
+ *RT = RegisterType::I;
+ return true;
default:
- return RegisterType::Invalid;
+ return false;
}
}
@@ -91,19 +100,18 @@ static ResourceClass getResourceClass(RegisterType RT) {
return ResourceClass::CBuffer;
case RegisterType::Sampler:
return ResourceClass::Sampler;
- default:
+ case RegisterType::C:
+ case RegisterType::I:
llvm_unreachable("unexpected RegisterType value");
}
}
DeclBindingInfo *ResourceBindings::addDeclBindingInfo(const VarDecl *VD,
- ResourceClass ResClass,
- int Size) {
+ ResourceClass ResClass) {
assert(getDeclBindingInfo(VD, ResClass) == nullptr &&
"DeclBindingInfo already added");
- if (DeclToBindingListIndex.find(VD) == DeclToBindingListIndex.end())
- DeclToBindingListIndex[VD] = BindingsList.size();
- return &BindingsList.emplace_back(DeclBindingInfo(VD, ResClass, Size));
+ DeclToBindingListIndex.try_emplace(VD, BindingsList.size());
+ return &BindingsList.emplace_back(VD, ResClass);
}
DeclBindingInfo *ResourceBindings::getDeclBindingInfo(const VarDecl *VD,
@@ -1050,19 +1058,18 @@ FindHandleTypeOnResource(const VarDecl *VD) {
// Walks though the global variable declaration, collects all resource binding
// requirements and adds them to Bindings
-void SemaHLSL::FindResourcesOnUserRecordDecl(const VarDecl *VD,
- const RecordType *RT, int Size) {
+void SemaHLSL::collectResourcesOnUserRecordDecl(const VarDecl *VD,
+ const RecordType *RT) {
const RecordDecl *RD = RT->getDecl();
for (FieldDecl *FD : RD->fields()) {
const Type *Ty = FD->getType()->getUnqualifiedDesugaredType();
- // Calculate array size and unwrap
- int ArraySize = 1;
+ // Unwrap arrays
+ // FIXME: Calculate array size while unwrapping
assert(!Ty->isIncompleteArrayType() &&
"incomplete arrays inside user defined types are not supported");
while (Ty->isConstantArrayType()) {
const ConstantArrayType *CAT = cast<ConstantArrayType>(Ty);
- ArraySize *= CAT->getSize().getSExtValue();
Ty = CAT->getElementType()->getUnqualifiedDesugaredType();
}
@@ -1074,23 +1081,26 @@ void SemaHLSL::FindResourcesOnUserRecordDecl(const VarDecl *VD,
FindHandleTypeOnResource(Ty)) {
ResourceClass RC = AttrResType->getAttrs().ResourceClass;
- // Add a new DeclBindingInfo to Bindings. Update the binding size if
- // a binding info already exists (there are multiple resources of same
- // resource class in this user decl)
- if (auto *DBI = Bindings.getDeclBindingInfo(VD, RC))
- DBI->Size += Size * ArraySize;
- else
- Bindings.addDeclBindingInfo(VD, RC, Size);
+ // Add a new DeclBindingInfo to Bindings if it does not already exist
+ DeclBindingInfo *DBI = Bindings.getDeclBindingInfo(VD, RC);
+ if (!DBI)
+ Bindings.addDeclBindingInfo(VD, RC);
} else if (const RecordType *RT = dyn_cast<RecordType>(Ty)) {
// Recursively scan embedded struct or class; it would be nice to do this
- // without recursion, but tricky to corrently calculate the size.
- // Hopefully nesting of structs in structs too many levels is unlikely.
- FindResourcesOnUserRecordDecl(VD, RT, Size);
+ // without recursion, but that makes it tricky to correctly calculate the
+ // size of the binding, which we will probably need to do later on.
+ // Hopefully structs are rarely nested deeply enough for the recursion
+ // depth to matter.
+ collectResourcesOnUserRecordDecl(VD, RT);
}
}
}
-// return false if the register binding is not valid
+// Diagnose localized register binding errors for a single binding; does not
+// diagnose resource bindings on user record types; those are diagnosed later
+// in processExplicitBindingsOnDecl based on the information collected in
+// collectResourcesOnVarDecl.
+// Returns false if the register binding is not valid.
static bool DiagnoseLocalRegisterBinding(Sema &S, SourceLocation &ArgLoc,
Decl *D, RegisterType RegType,
bool SpecifiedSpace) {
@@ -1155,7 +1165,7 @@ static bool DiagnoseLocalRegisterBinding(Sema &S, SourceLocation &ArgLoc,
return false;
}
if (Ty->isRecordType())
- // RecordTypes will be diagnosed in ProcessResourceBindingOnDecl
+ // RecordTypes will be diagnosed in processExplicitBindingsOnDecl
// that is called from ActOnVariableDeclarator
return true;
@@ -1175,7 +1185,7 @@ static bool ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl,
if (HLSLResourceBindingAttr *attr =
dyn_cast<HLSLResourceBindingAttr>(*it)) {
- RegisterType otherRegType = getRegisterType(attr->getSlot());
+ RegisterType otherRegType = attr->getRegisterType();
if (RegisterTypesDetected[static_cast<int>(otherRegType)]) {
int otherRegTypeNum = static_cast<int>(otherRegType);
S.Diag(TheDecl->getLocation(),
@@ -1250,13 +1260,12 @@ void SemaHLSL::handleResourceBindingAttr(Decl *TheDecl, const ParsedAttr &AL) {
// Validate.
if (!Slot.empty()) {
- RegType = getRegisterType(Slot);
- if (RegType == RegisterType::I) {
- Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_i);
+ if (!convertToRegisterType(Slot, &RegType)) {
+ Diag(ArgLoc, diag::err_hlsl_binding_type_invalid) << Slot.substr(0, 1);
return;
}
- if (RegType == RegisterType::Invalid) {
- Diag(ArgLoc, diag::err_hlsl_binding_type_invalid) << Slot.substr(0, 1);
+ if (RegType == RegisterType::I) {
+ Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_i);
return;
}
@@ -2294,60 +2303,51 @@ void SemaHLSL::ActOnVariableDeclarator(VarDecl *VD) {
// find all resources on decl
if (IsIntangibleType(VD->getType()))
- FindResourcesOnVarDecl(VD);
+ collectResourcesOnVarDecl(VD);
// process explicit bindings
- ProcessExplicitBindingsOnDecl(VD);
+ processExplicitBindingsOnDecl(VD);
}
}
// Walks though the global variable declaration, collects all resource binding
// requirements and adds them to Bindings
-void SemaHLSL::FindResourcesOnVarDecl(VarDecl *VD) {
+void SemaHLSL::collectResourcesOnVarDecl(VarDecl *VD) {
assert(VD->hasGlobalStorage() && IsIntangibleType(VD->getType()) &&
"expected global variable that contains HLSL resource");
// Cbuffers and Tbuffers are HLSLBufferDecl types
if (const HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(VD)) {
- Bindings.addDeclBindingInfo(VD,
- CBufferOrTBuffer->isCBuffer()
- ? ResourceClass::CBuffer
- : ResourceClass::SRV,
- 1);
+ Bindings.addDeclBindingInfo(VD, CBufferOrTBuffer->isCBuffer()
+ ? ResourceClass::CBuffer
+ : ResourceClass::SRV);
return;
}
- // Calculate size of array and unwrap
- int Size = 1;
+ // Unwrap arrays
+ // FIXME: Calculate array size while unwrapping
const Type *Ty = VD->getType()->getUnqualifiedDesugaredType();
- if (Ty->isIncompleteArrayType())
- Size = -1;
while (Ty->isConstantArrayType()) {
const ConstantArrayType *CAT = cast<ConstantArrayType>(Ty);
- Size *= CAT->getSize().getSExtValue();
Ty = CAT->getElementType()->getUnqualifiedDesugaredType();
}
// Resource (or array of resources)
if (const HLSLAttributedResourceType *AttrResType =
FindHandleTypeOnResource(Ty)) {
- Bindings.addDeclBindingInfo(VD, AttrResType->getAttrs().ResourceClass,
- Size);
+ Bindings.addDeclBindingInfo(VD, AttrResType->getAttrs().ResourceClass);
return;
}
- assert(Size != -1 &&
- "unbounded arrays of user defined types are not supported");
-
// User defined record type
if (const RecordType *RT = dyn_cast<RecordType>(Ty))
- FindResourcesOnUserRecordDecl(VD, RT, Size);
+ collectResourcesOnUserRecordDecl(VD, RT);
}
// Walks though the explicit resource binding attributes on the declaration,
// and makes sure there is a resource that matched the binding and updates
// DeclBindingInfoLists
-void SemaHLSL::ProcessExplicitBindingsOnDecl(VarDecl *VD) {
+void SemaHLSL::processExplicitBindingsOnDecl(VarDecl *VD) {
assert(VD->hasGlobalStorage() && "expected global variable");
for (Attr *A : VD->attrs()) {
@@ -2356,11 +2356,9 @@ void SemaHLSL::ProcessExplicitBindingsOnDecl(VarDecl *VD) {
continue;
RegisterType RT = RBA->getRegisterType();
- assert(RT != RegisterType::I && RT != RegisterType::Invalid &&
- "invalid or obsolete register type should never have an attribute "
- "created");
+ assert(RT != RegisterType::I && "invalid or obsolete register type should "
+ "never have an attribute created");
- // These were already diagnosed earlier
if (RT == RegisterType::C) {
if (Bindings.hasBindingInfoForDecl(VD))
SemaRef.Diag(VD->getLocation(),
@@ -2375,7 +2373,6 @@ void SemaHLSL::ProcessExplicitBindingsOnDecl(VarDecl *VD) {
ResourceClass RC = getResourceClass(RT);
if (DeclBindingInfo *BI = Bindings.getDeclBindingInfo(VD, RC)) {
// update binding info
- RBA->setSize(BI->Size);
BI->setBindingAttribute(RBA, BindingType::Explicit);
} else {
SemaRef.Diag(VD->getLocation(),
More information about the cfe-commits
mailing list