[clang] [HLSL] Collect explicit resource binding information (part 1) (PR #111203)
Helena Kotas via cfe-commits
cfe-commits at lists.llvm.org
Fri Oct 4 12:57:23 PDT 2024
https://github.com/hekota created https://github.com/llvm/llvm-project/pull/111203
Adds fields to `HLSLResourceBindingAttr` to store processed binding information. This will be used by CodeGen or Sema for resource initialization or overlapping mapping diagnostic.
Moves binding checks for user defined types (UDTs) to `ProcessResourceBindingOnDecl` (called from ActOnVariableDeclarator), which updates the information in the attribute and where additional processing of the explicit resource binding will be added in the future.
Changed `handleResourceBindingAttr` to not create the resource binding attribute if the local binding diagnostic detects errors.
Part 1 of #110719
>From f545a14e11556c91d10b14617e3588fe5eae6d42 Mon Sep 17 00:00:00 2001
From: Helena Kotas <hekotas at microsoft.com>
Date: Fri, 4 Oct 2024 12:21:51 -0700
Subject: [PATCH] [HLSL] Collect explicit resource binding information (part 1)
- Do not create resource binding attribute if it is not valid
- Store basic resource binding information on HLSLResourceBindingAttr
- Move UDT type checking to ActOnVariableDeclarator
Part 1 of #110719
---
clang/include/clang/Basic/Attr.td | 29 +++
clang/include/clang/Sema/SemaHLSL.h | 2 +
clang/lib/Sema/SemaDecl.cpp | 3 +
clang/lib/Sema/SemaHLSL.cpp | 227 ++++++++++++------
.../resource_binding_attr_error_udt.hlsl | 8 +-
5 files changed, 188 insertions(+), 81 deletions(-)
diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index fbcbf0ed416416..668c599da81390 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -4588,6 +4588,35 @@ def HLSLResourceBinding: InheritableAttr {
let LangOpts = [HLSL];
let Args = [StringArgument<"Slot">, StringArgument<"Space", 1>];
let Documentation = [HLSLResourceBindingDocs];
+ let AdditionalMembers = [{
+ enum class RegisterType : unsigned { SRV, UAV, CBuffer, Sampler, C, I, Invalid };
+
+ const FieldDecl *ResourceField = nullptr;
+ RegisterType RegType;
+ unsigned SlotNumber;
+ unsigned SpaceNumber;
+
+ void setBinding(RegisterType RT, unsigned SlotNum, unsigned SpaceNum) {
+ RegType = RT;
+ SlotNumber = SlotNum;
+ SpaceNumber = SpaceNum;
+ }
+ void setResourceField(const FieldDecl *FD) {
+ ResourceField = FD;
+ }
+ const FieldDecl *getResourceField() {
+ return ResourceField;
+ }
+ RegisterType getRegisterType() {
+ return RegType;
+ }
+ unsigned getSlotNumber() {
+ return SlotNumber;
+ }
+ unsigned getSpaceNumber() {
+ return SpaceNumber;
+ }
+ }];
}
def HLSLPackOffset: HLSLAnnotationAttr {
diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h
index fa957abc9791af..018e7ea5901a2b 100644
--- a/clang/include/clang/Sema/SemaHLSL.h
+++ b/clang/include/clang/Sema/SemaHLSL.h
@@ -28,6 +28,7 @@ class AttributeCommonInfo;
class IdentifierInfo;
class ParsedAttr;
class Scope;
+class VarDecl;
// FIXME: This can be hidden (as static function in SemaHLSL.cpp) once we no
// longer need to create builtin buffer types in HLSLExternalSemaSource.
@@ -62,6 +63,7 @@ class SemaHLSL : public SemaBase {
const Attr *A, llvm::Triple::EnvironmentType Stage,
std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages);
void DiagnoseAvailabilityViolations(TranslationUnitDecl *TU);
+ void ProcessResourceBindingOnDecl(VarDecl *D);
QualType handleVectorBinOpConversion(ExprResult &LHS, ExprResult &RHS,
QualType LHSType, QualType RHSType,
diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp
index 2bf610746bc317..8e27a5e068e702 100644
--- a/clang/lib/Sema/SemaDecl.cpp
+++ b/clang/lib/Sema/SemaDecl.cpp
@@ -7876,6 +7876,9 @@ NamedDecl *Sema::ActOnVariableDeclarator(
// Handle attributes prior to checking for duplicates in MergeVarDecl
ProcessDeclAttributes(S, NewVD, D);
+ if (getLangOpts().HLSL)
+ HLSL().ProcessResourceBindingOnDecl(NewVD);
+
// FIXME: This is probably the wrong location to be doing this and we should
// probably be doing this for more attributes (especially for function
// pointer attributes such as format, warn_unused_result, etc.). Ideally
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index fbcba201a351a6..568a8de30c1fc5 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -41,9 +41,7 @@
using namespace clang;
using llvm::dxil::ResourceClass;
-
-enum class RegisterType { SRV, UAV, CBuffer, Sampler, C, I, Invalid };
-
+using RegisterType = HLSLResourceBindingAttr::RegisterType;
static RegisterType getRegisterType(ResourceClass RC) {
switch (RC) {
case ResourceClass::SRV:
@@ -985,44 +983,43 @@ SemaHLSL::TakeLocForHLSLAttribute(const HLSLAttributedResourceType *RT) {
return LocInfo;
}
-// get the record decl from a var decl that we expect
-// represents a resource
-static CXXRecordDecl *getRecordDeclFromVarDecl(VarDecl *VD) {
- const Type *Ty = VD->getType()->getPointeeOrArrayElementType();
- assert(Ty && "Resource must have an element type.");
-
- if (Ty->isBuiltinType())
- return nullptr;
-
- CXXRecordDecl *TheRecordDecl = Ty->getAsCXXRecordDecl();
- assert(TheRecordDecl && "Resource should have a resource type declaration.");
- return TheRecordDecl;
-}
-
+// Returns handle type of a resource, if the VarDecl is a resource
+// or an array of resources
static const HLSLAttributedResourceType *
-findAttributedResourceTypeOnField(VarDecl *VD) {
+FindHandleTypeOnResource(const VarDecl *VD) {
+ // If VarDecl is a resource class, the first field must
+ // be the resource handle of type HLSLAttributedResourceType
assert(VD != nullptr && "expected VarDecl");
- if (RecordDecl *RD = getRecordDeclFromVarDecl(VD)) {
- for (auto *FD : RD->fields()) {
- if (const HLSLAttributedResourceType *AttrResType =
- dyn_cast<HLSLAttributedResourceType>(FD->getType().getTypePtr()))
- return AttrResType;
+ const Type *Ty = VD->getType()->getPointeeOrArrayElementType();
+ if (RecordDecl *RD = Ty->getAsCXXRecordDecl()) {
+ if (!RD->fields().empty()) {
+ const auto &FirstFD = RD->fields().begin();
+ return dyn_cast<HLSLAttributedResourceType>(
+ FirstFD->getType().getTypePtr());
}
}
return nullptr;
}
-// Iterate over RecordType fields and return true if any of them matched the
-// register type
-static bool ContainsResourceForRegisterType(Sema &S, const RecordType *RT,
- RegisterType RegType) {
+// Walks through the user defined record type, finds the resource field
+// whose resource class matches the register type of the binding attribute
+// and stores that field on the attribute via setResourceField.
+static bool
+ProcessResourceBindingOnUserRecordDecl(const RecordType *RT,
+ HLSLResourceBindingAttr *RBA) {
+
llvm::SmallVector<const Type *> TypesToScan;
TypesToScan.emplace_back(RT);
+ RegisterType RegType = RBA->getRegisterType();
while (!TypesToScan.empty()) {
const Type *T = TypesToScan.pop_back_val();
- while (T->isArrayType())
+
+ while (T->isArrayType()) {
+ // FIXME: calculate the binding size from the array dimensions (or
+ // unbounded for unsized array) size *= (size_of_array);
T = T->getArrayElementTypeNoTypeQual();
+ }
if (T->isIntegralOrEnumerationType() || T->isFloatingType()) {
if (RegType == RegisterType::C)
return true;
@@ -1037,8 +1034,12 @@ static bool ContainsResourceForRegisterType(Sema &S, const RecordType *RT,
if (const HLSLAttributedResourceType *AttrResType =
dyn_cast<HLSLAttributedResourceType>(FieldTy)) {
ResourceClass RC = AttrResType->getAttrs().ResourceClass;
- if (getRegisterType(RC) == RegType)
+ if (getRegisterType(RC) == RegType) {
+ assert(RBA->getResourceField() == nullptr &&
+ "multiple register bindings of the same type are not allowed");
+ RBA->setResourceField(FD);
return true;
+ }
} else {
TypesToScan.emplace_back(FD->getType().getTypePtr());
}
@@ -1047,26 +1048,28 @@ static bool ContainsResourceForRegisterType(Sema &S, const RecordType *RT,
return false;
}
-static void CheckContainsResourceForRegisterType(Sema &S,
- SourceLocation &ArgLoc,
- Decl *D, RegisterType RegType,
- bool SpecifiedSpace) {
+// return false if the register binding is not valid
+static bool DiagnoseLocalRegisterBinding(Sema &S, SourceLocation &ArgLoc,
+ Decl *D, RegisterType RegType,
+ bool SpecifiedSpace) {
int RegTypeNum = static_cast<int>(RegType);
// check if the decl type is groupshared
if (D->hasAttr<HLSLGroupSharedAddressSpaceAttr>()) {
S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
- return;
+ return false;
}
// Cbuffers and Tbuffers are HLSLBufferDecl types
if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D)) {
ResourceClass RC = CBufferOrTBuffer->isCBuffer() ? ResourceClass::CBuffer
: ResourceClass::SRV;
- if (RegType != getRegisterType(RC))
- S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch)
- << RegTypeNum;
- return;
+ if (RegType == getRegisterType(RC))
+ return true;
+
+ S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch)
+ << RegTypeNum;
+ return false;
}
// Samplers, UAVs, and SRVs are VarDecl types
@@ -1075,11 +1078,13 @@ static void CheckContainsResourceForRegisterType(Sema &S,
// Resource
if (const HLSLAttributedResourceType *AttrResType =
- findAttributedResourceTypeOnField(VD)) {
- if (RegType != getRegisterType(AttrResType->getAttrs().ResourceClass))
- S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch)
- << RegTypeNum;
- return;
+ FindHandleTypeOnResource(VD)) {
+ if (RegType == getRegisterType(AttrResType->getAttrs().ResourceClass))
+ return true;
+
+ S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch)
+ << RegTypeNum;
+ return false;
}
const clang::Type *Ty = VD->getType().getTypePtr();
@@ -1088,36 +1093,43 @@ static void CheckContainsResourceForRegisterType(Sema &S,
// Basic types
if (Ty->isArithmeticType()) {
+ bool IsValid = true;
bool DeclaredInCOrTBuffer = isa<HLSLBufferDecl>(D->getDeclContext());
- if (SpecifiedSpace && !DeclaredInCOrTBuffer)
+ if (SpecifiedSpace && !DeclaredInCOrTBuffer) {
S.Diag(ArgLoc, diag::err_hlsl_space_on_global_constant);
+ IsValid = false;
+ }
if (!DeclaredInCOrTBuffer &&
(Ty->isIntegralType(S.getASTContext()) || Ty->isFloatingType())) {
// Default Globals
if (RegType == RegisterType::CBuffer)
S.Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_b);
- else if (RegType != RegisterType::C)
+ else if (RegType != RegisterType::C) {
S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
+ IsValid = false;
+ }
} else {
if (RegType == RegisterType::C)
S.Diag(ArgLoc, diag::warn_hlsl_register_type_c_packoffset);
- else
+ else {
S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
+ IsValid = false;
+ }
}
- } else if (Ty->isRecordType()) {
- // Class/struct types - walk the declaration and check each field and
- // subclass
- if (!ContainsResourceForRegisterType(S, Ty->getAs<RecordType>(), RegType))
- S.Diag(D->getLocation(), diag::warn_hlsl_user_defined_type_missing_member)
- << RegTypeNum;
- } else {
- // Anything else is an error
- S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
+ return IsValid;
}
+ if (Ty->isRecordType())
+ // RecordTypes will be diagnosed in ProcessResourceBindingOnDecl
+ // that is called from ActOnVariableDeclarator
+ return true;
+
+ // Anything else is an error
+ S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
+ return false;
}
-static void ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl,
+static bool ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl,
RegisterType regType) {
// make sure that there are no two register annotations
// applied to the decl with the same register type
@@ -1135,21 +1147,19 @@ static void ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl,
RegisterType otherRegType = getRegisterType(attr->getSlot());
if (RegisterTypesDetected[static_cast<int>(otherRegType)]) {
- if (PreviousConflicts[TheDecl].count(otherRegType))
- continue;
int otherRegTypeNum = static_cast<int>(otherRegType);
S.Diag(TheDecl->getLocation(),
diag::err_hlsl_duplicate_register_annotation)
<< otherRegTypeNum;
- PreviousConflicts[TheDecl].insert(otherRegType);
- } else {
- RegisterTypesDetected[static_cast<int>(otherRegType)] = true;
+ return false;
}
+ RegisterTypesDetected[static_cast<int>(otherRegType)] = true;
}
}
+ return true;
}
-static void DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc,
+static bool DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc,
Decl *D, RegisterType RegType,
bool SpecifiedSpace) {
@@ -1159,10 +1169,11 @@ static void DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc,
"expecting VarDecl or HLSLBufferDecl");
// check if the declaration contains resource matching the register type
- CheckContainsResourceForRegisterType(S, ArgLoc, D, RegType, SpecifiedSpace);
+ if (!DiagnoseLocalRegisterBinding(S, ArgLoc, D, RegType, SpecifiedSpace))
+ return false;
// next, if multiple register annotations exist, check that none conflict.
- ValidateMultipleRegisterAnnotations(S, D, RegType);
+ return ValidateMultipleRegisterAnnotations(S, D, RegType);
}
void SemaHLSL::handleResourceBindingAttr(Decl *TheDecl, const ParsedAttr &AL) {
@@ -1203,23 +1214,24 @@ void SemaHLSL::handleResourceBindingAttr(Decl *TheDecl, const ParsedAttr &AL) {
Slot = Str;
}
- RegisterType regType;
+ RegisterType RegType;
+ unsigned SlotNum = 0;
+ unsigned SpaceNum = 0;
// Validate.
if (!Slot.empty()) {
- regType = getRegisterType(Slot);
- if (regType == RegisterType::I) {
+ RegType = getRegisterType(Slot);
+ if (RegType == RegisterType::I) {
Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_i);
return;
}
- if (regType == RegisterType::Invalid) {
+ if (RegType == RegisterType::Invalid) {
Diag(ArgLoc, diag::err_hlsl_binding_type_invalid) << Slot.substr(0, 1);
return;
}
- StringRef SlotNum = Slot.substr(1);
- unsigned Num = 0;
- if (SlotNum.getAsInteger(10, Num)) {
+ StringRef SlotNumStr = Slot.substr(1);
+ if (SlotNumStr.getAsInteger(10, SlotNum)) {
Diag(ArgLoc, diag::err_hlsl_unsupported_register_number);
return;
}
@@ -1229,20 +1241,22 @@ void SemaHLSL::handleResourceBindingAttr(Decl *TheDecl, const ParsedAttr &AL) {
Diag(SpaceArgLoc, diag::err_hlsl_expected_space) << Space;
return;
}
- StringRef SpaceNum = Space.substr(5);
- unsigned Num = 0;
- if (SpaceNum.getAsInteger(10, Num)) {
+ StringRef SpaceNumStr = Space.substr(5);
+ if (SpaceNumStr.getAsInteger(10, SpaceNum)) {
Diag(SpaceArgLoc, diag::err_hlsl_expected_space) << Space;
return;
}
- DiagnoseHLSLRegisterAttribute(SemaRef, ArgLoc, TheDecl, regType,
- SpecifiedSpace);
+ if (!DiagnoseHLSLRegisterAttribute(SemaRef, ArgLoc, TheDecl, RegType,
+ SpecifiedSpace))
+ return;
HLSLResourceBindingAttr *NewAttr =
HLSLResourceBindingAttr::Create(getASTContext(), Slot, Space, AL);
- if (NewAttr)
+ if (NewAttr) {
+ NewAttr->setBinding(RegType, SlotNum, SpaceNum);
TheDecl->addAttr(NewAttr);
+ }
}
void SemaHLSL::handleParamModifierAttr(Decl *D, const ParsedAttr &AL) {
@@ -2228,3 +2242,62 @@ QualType SemaHLSL::getInoutParameterType(QualType Ty) {
Ty.addRestrict();
return Ty;
}
+
+// Walks through existing explicit bindings, finds the actual resource class
+// decl the binding applies to and sets it to attr->ResourceField.
+// Additional processing of resource binding can be added here later on,
+// such as preparation for overlapping resource detection or implicit binding.
+void SemaHLSL::ProcessResourceBindingOnDecl(VarDecl *D) {
+ if (!D->hasGlobalStorage())
+ return;
+
+ for (Attr *A : D->attrs()) {
+ HLSLResourceBindingAttr *RBA = dyn_cast<HLSLResourceBindingAttr>(A);
+ if (!RBA)
+ continue;
+
+ // Cbuffers and Tbuffers are HLSLBufferDecl types
+ if (const HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D)) {
+ assert(RBA->getRegisterType() ==
+ getRegisterType(CBufferOrTBuffer->isCBuffer()
+ ? ResourceClass::CBuffer
+ : ResourceClass::SRV) &&
+ "this should have been handled in DiagnoseLocalRegisterBinding");
+ // should we handle HLSLBufferDecl here?
+ continue;
+ }
+
+ // Samplers, UAVs, and SRVs are VarDecl types
+ assert(isa<VarDecl>(D) && "D is expected to be VarDecl or HLSLBufferDecl");
+ const VarDecl *VD = cast<VarDecl>(D);
+
+ // Register binding directly on global resource class variable
+ if (const HLSLAttributedResourceType *AttrResType =
+ FindHandleTypeOnResource(VD)) {
+ // FIXME: if array, calculate the binding size from the array dimensions
+ // (or unbounded for unsized array)
+ assert(RBA->getResourceField() == nullptr);
+ continue;
+ }
+
+ // Global array
+ const clang::Type *Ty = VD->getType().getTypePtr();
+ while (Ty->isArrayType()) {
+ Ty = Ty->getArrayElementTypeNoTypeQual();
+ }
+
+ // Basic types
+ if (Ty->isArithmeticType()) {
+ continue;
+ }
+
+ if (Ty->isRecordType()) {
+ if (!ProcessResourceBindingOnUserRecordDecl(Ty->getAs<RecordType>(),
+ RBA)) {
+ SemaRef.Diag(D->getLocation(),
+ diag::warn_hlsl_user_defined_type_missing_member)
+ << static_cast<int>(RBA->getRegisterType());
+ }
+ }
+ }
+}
diff --git a/clang/test/SemaHLSL/resource_binding_attr_error_udt.hlsl b/clang/test/SemaHLSL/resource_binding_attr_error_udt.hlsl
index ea2d576e4cca55..40517f393e1284 100644
--- a/clang/test/SemaHLSL/resource_binding_attr_error_udt.hlsl
+++ b/clang/test/SemaHLSL/resource_binding_attr_error_udt.hlsl
@@ -106,7 +106,6 @@ struct Eg12{
MySRV s1;
MySRV s2;
};
-// expected-warning at +3{{binding type 'u' only applies to types containing UAV resources}}
// expected-warning at +2{{binding type 'u' only applies to types containing UAV resources}}
// expected-error at +1{{binding type 'u' cannot be applied more than once}}
Eg12 e12 : register(u9) : register(u10);
@@ -115,12 +114,14 @@ struct Eg13{
MySRV s1;
MySRV s2;
};
-// expected-warning at +4{{binding type 'u' only applies to types containing UAV resources}}
// expected-warning at +3{{binding type 'u' only applies to types containing UAV resources}}
-// expected-warning at +2{{binding type 'u' only applies to types containing UAV resources}}
+// expected-error at +2{{binding type 'u' cannot be applied more than once}}
// expected-error at +1{{binding type 'u' cannot be applied more than once}}
Eg13 e13 : register(u9) : register(u10) : register(u11);
+// expected-error at +1{{binding type 't' cannot be applied more than once}}
+Eg13 e13_2 : register(t11) : register(t12);
+
struct Eg14{
MyTemplatedUAV<int> r1;
};
@@ -132,4 +133,3 @@ struct Eg15 {
};
// expected no error
Eg15 e15 : register(c0);
-
More information about the cfe-commits
mailing list