[clang] [HLSL] Adjust resource binding diagnostic flags code (PR #106657)
via cfe-commits
cfe-commits at lists.llvm.org
Thu Aug 29 20:54:30 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-clang
Author: Helena Kotas (hekota)
<details>
<summary>Changes</summary>
Adjust register binding diagnostic flags code in a couple of ways:
- Store the resource class in the Flags struct so the declaration is not scanned for HLSLResourceClassAttr more than once
- Avoid unnecessary indirection when converting a resource class to a register type
- Remove recursion and reduce duplicated code
Also fixes a case where a struct containing an array field was incorrectly diagnosed as unfit for `c` register binding.
This will also simplify the work that needs to be done in this area for llvm/llvm-project#104861.
---
Full diff: https://github.com/llvm/llvm-project/pull/106657.diff
2 Files Affected:
- (modified) clang/lib/Sema/SemaHLSL.cpp (+68-113)
- (modified) clang/test/SemaHLSL/resource_binding_attr_error_udt.hlsl (+7)
``````````diff
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 714e8f5cfa9926..1e484f754b931d 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -480,6 +480,9 @@ struct RegisterBindingFlags {
bool ContainsNumeric = false;
bool DefaultGlobals = false;
+
+ // used only when Resource == true
+ llvm::dxil::ResourceClass ResourceClass = llvm::dxil::ResourceClass::UAV;
};
static bool isDeclaredWithinCOrTBuffer(const Decl *TheDecl) {
@@ -545,65 +548,38 @@ static const T *getSpecifiedHLSLAttrFromVarDecl(VarDecl *VD) {
return getSpecifiedHLSLAttrFromRecordDecl<T>(TheRecordDecl);
}
-static void updateFlagsFromType(QualType TheQualTy,
- RegisterBindingFlags &Flags);
-
-static void updateResourceClassFlagsFromRecordDecl(RegisterBindingFlags &Flags,
- const RecordDecl *RD) {
- if (!RD)
- return;
-
- if (RD->isCompleteDefinition()) {
- for (auto Field : RD->fields()) {
- QualType T = Field->getType();
- updateFlagsFromType(T, Flags);
+static void updateResourceClassFlagsFromRecordType(RegisterBindingFlags &Flags,
+ const RecordType *RT) {
+ llvm::SmallVector<const Type *> TypesToScan;
+ TypesToScan.emplace_back(RT);
+
+ while (!TypesToScan.empty()) {
+ const Type *T = TypesToScan.pop_back_val();
+ while (T->isArrayType())
+ T = T->getArrayElementTypeNoTypeQual();
+ if (T->isIntegralOrEnumerationType() || T->isFloatingType()) {
+ Flags.ContainsNumeric = true;
+ continue;
}
- }
-}
-
-static void updateFlagsFromType(QualType TheQualTy,
- RegisterBindingFlags &Flags) {
- // if the member's type is a numeric type, set the ContainsNumeric flag
- if (TheQualTy->isIntegralOrEnumerationType() || TheQualTy->isFloatingType()) {
- Flags.ContainsNumeric = true;
- return;
- }
-
- const clang::Type *TheBaseType = TheQualTy.getTypePtr();
- while (TheBaseType->isArrayType())
- TheBaseType = TheBaseType->getArrayElementTypeNoTypeQual();
- // otherwise, if the member's base type is not a record type, return
- const RecordType *TheRecordTy = TheBaseType->getAs<RecordType>();
- if (!TheRecordTy)
- return;
-
- RecordDecl *SubRecordDecl = TheRecordTy->getDecl();
- const HLSLResourceClassAttr *Attr =
- getSpecifiedHLSLAttrFromRecordDecl<HLSLResourceClassAttr>(SubRecordDecl);
- // find the attr if it's on the member, or on any of the member's fields
- if (Attr) {
- llvm::hlsl::ResourceClass DeclResourceClass = Attr->getResourceClass();
- updateResourceClassFlagsFromDeclResourceClass(Flags, DeclResourceClass);
- }
+ const RecordType *RT = T->getAs<RecordType>();
+ if (!RT)
+ continue;
- // otherwise, dig deeper and recurse into the member
- else {
- updateResourceClassFlagsFromRecordDecl(Flags, SubRecordDecl);
+ const RecordDecl *RD = RT->getDecl();
+ for (FieldDecl *FD : RD->fields()) {
+ if (HLSLResourceClassAttr *RCAttr =
+ FD->getAttr<HLSLResourceClassAttr>()) {
+ updateResourceClassFlagsFromDeclResourceClass(
+ Flags, RCAttr->getResourceClass());
+ continue;
+ }
+ TypesToScan.emplace_back(FD->getType().getTypePtr());
+ }
}
}
static RegisterBindingFlags HLSLFillRegisterBindingFlags(Sema &S,
Decl *TheDecl) {
-
- // Cbuffers and Tbuffers are HLSLBufferDecl types
- HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(TheDecl);
- // Samplers, UAVs, and SRVs are VarDecl types
- VarDecl *TheVarDecl = dyn_cast<VarDecl>(TheDecl);
-
- assert(((TheVarDecl && !CBufferOrTBuffer) ||
- (!TheVarDecl && CBufferOrTBuffer)) &&
- "either TheVarDecl or CBufferOrTBuffer should be set");
-
RegisterBindingFlags Flags;
// check if the decl type is groupshared
@@ -612,57 +588,61 @@ static RegisterBindingFlags HLSLFillRegisterBindingFlags(Sema &S,
return Flags;
}
- if (!isDeclaredWithinCOrTBuffer(TheDecl)) {
- // make sure the type is a basic / numeric type
- if (TheVarDecl) {
- QualType TheQualTy = TheVarDecl->getType();
- // a numeric variable or an array of numeric variables
- // will inevitably end up in $Globals buffer
- const clang::Type *TheBaseType = TheQualTy.getTypePtr();
- while (TheBaseType->isArrayType())
- TheBaseType = TheBaseType->getArrayElementTypeNoTypeQual();
- if (TheBaseType->isIntegralType(S.getASTContext()) ||
- TheBaseType->isFloatingType())
- Flags.DefaultGlobals = true;
- }
- }
-
- if (CBufferOrTBuffer) {
+ // Cbuffers and Tbuffers are HLSLBufferDecl types
+ if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(TheDecl)) {
Flags.Resource = true;
- if (CBufferOrTBuffer->isCBuffer())
- Flags.CBV = true;
- else
- Flags.SRV = true;
- } else if (TheVarDecl) {
+ Flags.ResourceClass = CBufferOrTBuffer->isCBuffer()
+ ? llvm::dxil::ResourceClass::CBuffer
+ : llvm::dxil::ResourceClass::SRV;
+ }
+ // Samplers, UAVs, and SRVs are VarDecl types
+ else if (VarDecl *TheVarDecl = dyn_cast<VarDecl>(TheDecl)) {
const HLSLResourceClassAttr *resClassAttr =
getSpecifiedHLSLAttrFromVarDecl<HLSLResourceClassAttr>(TheVarDecl);
-
if (resClassAttr) {
- llvm::hlsl::ResourceClass DeclResourceClass =
- resClassAttr->getResourceClass();
Flags.Resource = true;
- updateResourceClassFlagsFromDeclResourceClass(Flags, DeclResourceClass);
+ Flags.ResourceClass = resClassAttr->getResourceClass();
} else {
const clang::Type *TheBaseType = TheVarDecl->getType().getTypePtr();
while (TheBaseType->isArrayType())
TheBaseType = TheBaseType->getArrayElementTypeNoTypeQual();
- if (TheBaseType->isArithmeticType())
+
+ if (TheBaseType->isArithmeticType()) {
Flags.Basic = true;
- else if (TheBaseType->isRecordType()) {
+ if (!isDeclaredWithinCOrTBuffer(TheDecl) &&
+ (TheBaseType->isIntegralType(S.getASTContext()) ||
+ TheBaseType->isFloatingType()))
+ Flags.DefaultGlobals = true;
+ } else if (TheBaseType->isRecordType()) {
Flags.UDT = true;
const RecordType *TheRecordTy = TheBaseType->getAs<RecordType>();
- assert(TheRecordTy && "The Qual Type should be Record Type");
- const RecordDecl *TheRecordDecl = TheRecordTy->getDecl();
- // recurse through members, set appropriate resource class flags.
- updateResourceClassFlagsFromRecordDecl(Flags, TheRecordDecl);
+ updateResourceClassFlagsFromRecordType(Flags, TheRecordTy);
} else
Flags.Other = true;
}
+ } else {
+ llvm_unreachable("expected be VarDecl or HLSLBufferDecl");
}
return Flags;
}
-enum class RegisterType { SRV, UAV, CBuffer, Sampler, C, I, Invalid };
+enum class RegisterType {
+ SRV = static_cast<int>(llvm::dxil::ResourceClass::SRV),
+ UAV = static_cast<int>(llvm::dxil::ResourceClass::UAV),
+ CBuffer = static_cast<int>(llvm::dxil::ResourceClass::CBuffer),
+ Sampler = static_cast<int>(llvm::dxil::ResourceClass::Sampler),
+ C,
+ I,
+ Invalid
+};
+
+static RegisterType
+convertResourceClassToRegisterType(llvm::dxil::ResourceClass RC) {
+ assert(RC >= llvm::dxil::ResourceClass::SRV &&
+ RC <= llvm::dxil::ResourceClass::Sampler &&
+ "unexpected resource class value");
+ return static_cast<RegisterType>(RC);
+}
static RegisterType getRegisterType(StringRef Slot) {
switch (Slot[0]) {
@@ -754,34 +734,9 @@ static void DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc,
// next, if resource is set, make sure the register type in the register
// annotation is compatible with the variable's resource type.
if (Flags.Resource) {
- const HLSLResourceClassAttr *resClassAttr = nullptr;
- if (CBufferOrTBuffer) {
- resClassAttr = CBufferOrTBuffer->getAttr<HLSLResourceClassAttr>();
- } else if (TheVarDecl) {
- resClassAttr =
- getSpecifiedHLSLAttrFromVarDecl<HLSLResourceClassAttr>(TheVarDecl);
- }
-
- assert(resClassAttr &&
- "any decl that set the resource flag on analysis should "
- "have a resource class attribute attached.");
- const llvm::hlsl::ResourceClass DeclResourceClass =
- resClassAttr->getResourceClass();
-
- // confirm that the register type is bound to its expected resource class
- static RegisterType ExpectedRegisterTypesForResourceClass[] = {
- RegisterType::SRV,
- RegisterType::UAV,
- RegisterType::CBuffer,
- RegisterType::Sampler,
- };
- assert((size_t)DeclResourceClass <
- std::size(ExpectedRegisterTypesForResourceClass) &&
- "DeclResourceClass has unexpected value");
-
- RegisterType ExpectedRegisterType =
- ExpectedRegisterTypesForResourceClass[(int)DeclResourceClass];
- if (regType != ExpectedRegisterType) {
+ RegisterType expRegType =
+ convertResourceClassToRegisterType(Flags.ResourceClass);
+ if (regType != expRegType) {
S.Diag(TheDecl->getLocation(), diag::err_hlsl_binding_type_mismatch)
<< regTypeNum;
}
@@ -823,7 +778,7 @@ static void DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc,
}
void SemaHLSL::handleResourceBindingAttr(Decl *TheDecl, const ParsedAttr &AL) {
- if (dyn_cast<VarDecl>(TheDecl)) {
+ if (isa<VarDecl>(TheDecl)) {
if (SemaRef.RequireCompleteType(TheDecl->getBeginLoc(),
cast<ValueDecl>(TheDecl)->getType(),
diag::err_incomplete_type))
diff --git a/clang/test/SemaHLSL/resource_binding_attr_error_udt.hlsl b/clang/test/SemaHLSL/resource_binding_attr_error_udt.hlsl
index f8e38b6d2851d9..edb3f30739cdfd 100644
--- a/clang/test/SemaHLSL/resource_binding_attr_error_udt.hlsl
+++ b/clang/test/SemaHLSL/resource_binding_attr_error_udt.hlsl
@@ -126,3 +126,10 @@ struct Eg14{
};
// expected-warning@+1{{binding type 't' only applies to types containing SRV resources}}
Eg14 e14 : register(t9);
+
+struct Eg15 {
+ float f[4];
+};
+// expected no error
+Eg15 e15 : register(c0);
+
``````````
</details>
https://github.com/llvm/llvm-project/pull/106657
More information about the cfe-commits
mailing list