[clang] [HLSL] Use llvm::Triple::EnvironmentType instead of HLSLShaderAttr::ShaderType (PR #93847)
via cfe-commits
cfe-commits at lists.llvm.org
Thu May 30 10:29:08 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-clang-codegen
Author: Helena Kotas (hekota)
<details>
<summary>Changes</summary>
The `HLSLShaderAttr::ShaderType` enum is a subset of `llvm::Triple::EnvironmentType`, so we can use `llvm::Triple::EnvironmentType` directly and avoid converting one enum to the other.
---
Full diff: https://github.com/llvm/llvm-project/pull/93847.diff
5 Files Affected:
- (modified) clang/include/clang/Basic/Attr.td (+5-21)
- (modified) clang/include/clang/Sema/SemaHLSL.h (+3-3)
- (modified) clang/lib/CodeGen/CGHLSLRuntime.cpp (+1-1)
- (modified) clang/lib/Sema/SemaDeclAttr.cpp (+2-2)
- (modified) clang/lib/Sema/SemaHLSL.cpp (+58-47)
``````````diff
diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index 2665b7353ca4a..a337509d3e2b5 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -4469,36 +4469,20 @@ def HLSLShader : InheritableAttr {
let Subjects = SubjectList<[HLSLEntry]>;
let LangOpts = [HLSL];
let Args = [
- EnumArgument<"Type", "ShaderType", /*is_string=*/true,
+ EnumArgument<"Type", "llvm::Triple::EnvironmentType", /*is_string=*/true,
["pixel", "vertex", "geometry", "hull", "domain", "compute",
"raygeneration", "intersection", "anyhit", "closesthit",
"miss", "callable", "mesh", "amplification"],
["Pixel", "Vertex", "Geometry", "Hull", "Domain", "Compute",
"RayGeneration", "Intersection", "AnyHit", "ClosestHit",
- "Miss", "Callable", "Mesh", "Amplification"]>
+ "Miss", "Callable", "Mesh", "Amplification"],
+ /*opt=*/0, /*fake=*/0, /*isExternalType=*/1>
];
let Documentation = [HLSLSV_ShaderTypeAttrDocs];
let AdditionalMembers =
[{
- static const unsigned ShaderTypeMaxValue = (unsigned)HLSLShaderAttr::Amplification;
-
- static llvm::Triple::EnvironmentType getTypeAsEnvironment(HLSLShaderAttr::ShaderType ShaderType) {
- switch (ShaderType) {
- case HLSLShaderAttr::Pixel: return llvm::Triple::Pixel;
- case HLSLShaderAttr::Vertex: return llvm::Triple::Vertex;
- case HLSLShaderAttr::Geometry: return llvm::Triple::Geometry;
- case HLSLShaderAttr::Hull: return llvm::Triple::Hull;
- case HLSLShaderAttr::Domain: return llvm::Triple::Domain;
- case HLSLShaderAttr::Compute: return llvm::Triple::Compute;
- case HLSLShaderAttr::RayGeneration: return llvm::Triple::RayGeneration;
- case HLSLShaderAttr::Intersection: return llvm::Triple::Intersection;
- case HLSLShaderAttr::AnyHit: return llvm::Triple::AnyHit;
- case HLSLShaderAttr::ClosestHit: return llvm::Triple::ClosestHit;
- case HLSLShaderAttr::Miss: return llvm::Triple::Miss;
- case HLSLShaderAttr::Callable: return llvm::Triple::Callable;
- case HLSLShaderAttr::Mesh: return llvm::Triple::Mesh;
- case HLSLShaderAttr::Amplification: return llvm::Triple::Amplification;
- }
+ static bool isValidShaderType(llvm::Triple::EnvironmentType ShaderType) {
+ return ShaderType >= llvm::Triple::Pixel && ShaderType <= llvm::Triple::Amplification;
}
}];
}
diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h
index eac1f7c07c85d..00df6c2bd15e4 100644
--- a/clang/include/clang/Sema/SemaHLSL.h
+++ b/clang/include/clang/Sema/SemaHLSL.h
@@ -38,7 +38,7 @@ class SemaHLSL : public SemaBase {
const AttributeCommonInfo &AL, int X,
int Y, int Z);
HLSLShaderAttr *mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
- HLSLShaderAttr::ShaderType ShaderType);
+ llvm::Triple::EnvironmentType ShaderType);
HLSLParamModifierAttr *
mergeParamModifierAttr(Decl *D, const AttributeCommonInfo &AL,
HLSLParamModifierAttr::Spelling Spelling);
@@ -47,8 +47,8 @@ class SemaHLSL : public SemaBase {
void CheckSemanticAnnotation(FunctionDecl *EntryPoint, const Decl *Param,
const HLSLAnnotationAttr *AnnotationAttr);
void DiagnoseAttrStageMismatch(
- const Attr *A, HLSLShaderAttr::ShaderType Stage,
- std::initializer_list<HLSLShaderAttr::ShaderType> AllowedStages);
+ const Attr *A, llvm::Triple::EnvironmentType Stage,
+ std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages);
void DiagnoseAvailabilityViolations(TranslationUnitDecl *TU);
};
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.cpp b/clang/lib/CodeGen/CGHLSLRuntime.cpp
index 5e6a3dd4878f4..55ba21ae2ba69 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.cpp
+++ b/clang/lib/CodeGen/CGHLSLRuntime.cpp
@@ -313,7 +313,7 @@ void clang::CodeGen::CGHLSLRuntime::setHLSLEntryAttributes(
assert(ShaderAttr && "All entry functions must have a HLSLShaderAttr");
const StringRef ShaderAttrKindStr = "hlsl.shader";
Fn->addFnAttr(ShaderAttrKindStr,
- ShaderAttr->ConvertShaderTypeToStr(ShaderAttr->getType()));
+ llvm::Triple::getEnvironmentTypeName(ShaderAttr->getType()));
if (HLSLNumThreadsAttr *NumThreadsAttr = FD->getAttr<HLSLNumThreadsAttr>()) {
const StringRef NumThreadsKindStr = "hlsl.numthreads";
std::string NumThreadsStr =
diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
index 7c1fb23b90728..49c9de73aafb5 100644
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -7341,8 +7341,8 @@ static void handleHLSLShaderAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
if (!S.checkStringLiteralArgumentAttr(AL, 0, Str, &ArgLoc))
return;
- HLSLShaderAttr::ShaderType ShaderType;
- if (!HLSLShaderAttr::ConvertStrToShaderType(Str, ShaderType)) {
+ llvm::Triple::EnvironmentType ShaderType;
+ if (!HLSLShaderAttr::ConvertStrToEnvironmentType(Str, ShaderType)) {
S.Diag(AL.getLoc(), diag::warn_attribute_type_not_supported)
<< AL << Str << ArgLoc;
return;
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 9e614ae99f37d..da9bda3eaf3d9 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -145,7 +145,7 @@ HLSLNumThreadsAttr *SemaHLSL::mergeNumThreadsAttr(Decl *D,
HLSLShaderAttr *
SemaHLSL::mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
- HLSLShaderAttr::ShaderType ShaderType) {
+ llvm::Triple::EnvironmentType ShaderType) {
if (HLSLShaderAttr *NT = D->getAttr<HLSLShaderAttr>()) {
if (NT->getType() != ShaderType) {
Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
@@ -183,13 +183,12 @@ void SemaHLSL::ActOnTopLevelFunction(FunctionDecl *FD) {
if (FD->getName() != TargetInfo.getTargetOpts().HLSLEntry)
return;
- StringRef Env = TargetInfo.getTriple().getEnvironmentName();
- HLSLShaderAttr::ShaderType ShaderType;
- if (HLSLShaderAttr::ConvertStrToShaderType(Env, ShaderType)) {
+ llvm::Triple::EnvironmentType Env = TargetInfo.getTriple().getEnvironment();
+ if (HLSLShaderAttr::isValidShaderType(Env) && Env != llvm::Triple::Library) {
if (const auto *Shader = FD->getAttr<HLSLShaderAttr>()) {
// The entry point is already annotated - check that it matches the
// triple.
- if (Shader->getType() != ShaderType) {
+ if (Shader->getType() != Env) {
Diag(Shader->getLocation(), diag::err_hlsl_entry_shader_attr_mismatch)
<< Shader;
FD->setInvalidDecl();
@@ -197,11 +196,11 @@ void SemaHLSL::ActOnTopLevelFunction(FunctionDecl *FD) {
} else {
// Implicitly add the shader attribute if the entry function isn't
// explicitly annotated.
- FD->addAttr(HLSLShaderAttr::CreateImplicit(getASTContext(), ShaderType,
+ FD->addAttr(HLSLShaderAttr::CreateImplicit(getASTContext(), Env,
FD->getBeginLoc()));
}
} else {
- switch (TargetInfo.getTriple().getEnvironment()) {
+ switch (Env) {
case llvm::Triple::UnknownEnvironment:
case llvm::Triple::Library:
break;
@@ -214,38 +213,40 @@ void SemaHLSL::ActOnTopLevelFunction(FunctionDecl *FD) {
void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) {
const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
assert(ShaderAttr && "Entry point has no shader attribute");
- HLSLShaderAttr::ShaderType ST = ShaderAttr->getType();
+ llvm::Triple::EnvironmentType ST = ShaderAttr->getType();
switch (ST) {
- case HLSLShaderAttr::Pixel:
- case HLSLShaderAttr::Vertex:
- case HLSLShaderAttr::Geometry:
- case HLSLShaderAttr::Hull:
- case HLSLShaderAttr::Domain:
- case HLSLShaderAttr::RayGeneration:
- case HLSLShaderAttr::Intersection:
- case HLSLShaderAttr::AnyHit:
- case HLSLShaderAttr::ClosestHit:
- case HLSLShaderAttr::Miss:
- case HLSLShaderAttr::Callable:
+ case llvm::Triple::Pixel:
+ case llvm::Triple::Vertex:
+ case llvm::Triple::Geometry:
+ case llvm::Triple::Hull:
+ case llvm::Triple::Domain:
+ case llvm::Triple::RayGeneration:
+ case llvm::Triple::Intersection:
+ case llvm::Triple::AnyHit:
+ case llvm::Triple::ClosestHit:
+ case llvm::Triple::Miss:
+ case llvm::Triple::Callable:
if (const auto *NT = FD->getAttr<HLSLNumThreadsAttr>()) {
DiagnoseAttrStageMismatch(NT, ST,
- {HLSLShaderAttr::Compute,
- HLSLShaderAttr::Amplification,
- HLSLShaderAttr::Mesh});
+ {llvm::Triple::Compute,
+ llvm::Triple::Amplification,
+ llvm::Triple::Mesh});
FD->setInvalidDecl();
}
break;
- case HLSLShaderAttr::Compute:
- case HLSLShaderAttr::Amplification:
- case HLSLShaderAttr::Mesh:
+ case llvm::Triple::Compute:
+ case llvm::Triple::Amplification:
+ case llvm::Triple::Mesh:
if (!FD->hasAttr<HLSLNumThreadsAttr>()) {
Diag(FD->getLocation(), diag::err_hlsl_missing_numthreads)
- << HLSLShaderAttr::ConvertShaderTypeToStr(ST);
+ << llvm::Triple::getEnvironmentTypeName(ST);
FD->setInvalidDecl();
}
break;
+ default:
+ llvm_unreachable("Unhandled environment in triple");
}
for (ParmVarDecl *Param : FD->parameters()) {
@@ -267,14 +268,14 @@ void SemaHLSL::CheckSemanticAnnotation(
const HLSLAnnotationAttr *AnnotationAttr) {
auto *ShaderAttr = EntryPoint->getAttr<HLSLShaderAttr>();
assert(ShaderAttr && "Entry point has no shader attribute");
- HLSLShaderAttr::ShaderType ST = ShaderAttr->getType();
+ llvm::Triple::EnvironmentType ST = ShaderAttr->getType();
switch (AnnotationAttr->getKind()) {
case attr::HLSLSV_DispatchThreadID:
case attr::HLSLSV_GroupIndex:
- if (ST == HLSLShaderAttr::Compute)
+ if (ST == llvm::Triple::Compute)
return;
- DiagnoseAttrStageMismatch(AnnotationAttr, ST, {HLSLShaderAttr::Compute});
+ DiagnoseAttrStageMismatch(AnnotationAttr, ST, {llvm::Triple::Compute});
break;
default:
llvm_unreachable("Unknown HLSLAnnotationAttr");
@@ -282,16 +283,16 @@ void SemaHLSL::CheckSemanticAnnotation(
}
void SemaHLSL::DiagnoseAttrStageMismatch(
- const Attr *A, HLSLShaderAttr::ShaderType Stage,
- std::initializer_list<HLSLShaderAttr::ShaderType> AllowedStages) {
+ const Attr *A, llvm::Triple::EnvironmentType Stage,
+ std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages) {
SmallVector<StringRef, 8> StageStrings;
llvm::transform(AllowedStages, std::back_inserter(StageStrings),
- [](HLSLShaderAttr::ShaderType ST) {
+ [](llvm::Triple::EnvironmentType ST) {
return StringRef(
- HLSLShaderAttr::ConvertShaderTypeToStr(ST));
+ HLSLShaderAttr::ConvertEnvironmentTypeToStr(ST));
});
Diag(A->getLoc(), diag::err_hlsl_attr_unsupported_in_stage)
- << A << HLSLShaderAttr::ConvertShaderTypeToStr(Stage)
+ << A << llvm::Triple::getEnvironmentTypeName(Stage)
<< (AllowedStages.size() != 1) << join(StageStrings, ", ");
}
@@ -321,16 +322,22 @@ class DiagnoseHLSLAvailability
//
// Maps FunctionDecl to an unsigned number that represents the set of shader
// environments the function has been scanned for.
- // Since HLSLShaderAttr::ShaderType enum is generated from Attr.td and is
- // defined without any assigned values, it is guaranteed to be numbered
- // sequentially from 0 up and we can use it to 'index' individual bits
- // in the set.
+ // The llvm::Triple::EnvironmentType enum values for shader stages are
+ // guaranteed to be numbered from llvm::Triple::Pixel to
+ // llvm::Triple::Amplification (verified by static_asserts in Triple.cpp), so
+ // we can use them to index individual bits in the set, as long as we shift
+ // the values to start with 0 by subtracting llvm::Triple::Pixel first.
+ //
// The N'th bit in the set will be set if the function has been scanned
- // in shader environment whose ShaderType integer value equals N.
+ // in shader environment whose llvm::Triple::EnvironmentType integer value
+ // equals (llvm::Triple::Pixel + N).
+ //
// For example, if a function has been scanned in compute and pixel stage
- // environment, the value will be 0x21 (100001 binary) because
- // (int)HLSLShaderAttr::ShaderType::Pixel == 1 and
- // (int)HLSLShaderAttr::ShaderType::Compute == 5.
+ // environment, the value will be 0x21 (100001 binary) because:
+ //
+ // (int)(llvm::Triple::Pixel - llvm::Triple::Pixel) == 0
+ // (int)(llvm::Triple::Compute - llvm::Triple::Pixel) == 5
+ //
// A FunctionDecl is mapped to 0 (or not included in the map) if it has not
// been scanned in any environment.
llvm::DenseMap<const FunctionDecl *, unsigned> ScannedDecls;
@@ -346,12 +353,16 @@ class DiagnoseHLSLAvailability
bool ReportOnlyShaderStageIssues;
// Helper methods for dealing with current stage context / environment
- void SetShaderStageContext(HLSLShaderAttr::ShaderType ShaderType) {
+ void SetShaderStageContext(llvm::Triple::EnvironmentType ShaderType) {
static_assert(sizeof(unsigned) >= 4);
- assert((unsigned)ShaderType < 31); // 31 is reserved for "unknown"
-
- CurrentShaderEnvironment = HLSLShaderAttr::getTypeAsEnvironment(ShaderType);
- CurrentShaderStageBit = (1 << ShaderType);
+ assert(HLSLShaderAttr::isValidShaderType(ShaderType));
+ assert((unsigned)(ShaderType - llvm::Triple::Pixel) < 31 &&
+ "ShaderType is too big for this bitmap"); // 31 is reserved for
+ // "unknown"
+
+ unsigned bitmapIndex = ShaderType - llvm::Triple::Pixel;
+ CurrentShaderEnvironment = ShaderType;
+ CurrentShaderStageBit = (1 << bitmapIndex);
}
void SetUnknownShaderStageContext() {
``````````
</details>
https://github.com/llvm/llvm-project/pull/93847
More information about the cfe-commits
mailing list