[clang] 5d87ba1 - [HLSL] Use llvm::Triple::EnvironmentType instead of HLSLShaderAttr::ShaderType (#93847)
via cfe-commits
cfe-commits at lists.llvm.org
Fri Jun 7 21:30:09 PDT 2024
Author: Helena Kotas
Date: 2024-06-07T21:30:04-07:00
New Revision: 5d87ba1c1f584dfbd5afaf187099b43681b2206d
URL: https://github.com/llvm/llvm-project/commit/5d87ba1c1f584dfbd5afaf187099b43681b2206d
DIFF: https://github.com/llvm/llvm-project/commit/5d87ba1c1f584dfbd5afaf187099b43681b2206d.diff
LOG: [HLSL] Use llvm::Triple::EnvironmentType instead of HLSLShaderAttr::ShaderType (#93847)
`HLSLShaderAttr::ShaderType` enum is a subset of
`llvm::Triple::EnvironmentType`. We can use
`llvm::Triple::EnvironmentType` directly and avoid converting one enum
to another.
Added:
Modified:
clang/include/clang/Basic/Attr.td
clang/include/clang/Sema/SemaHLSL.h
clang/lib/CodeGen/CGHLSLRuntime.cpp
clang/lib/Sema/SemaHLSL.cpp
Removed:
################################################################################
diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index 17d9a710d948b..b70b0c8b836a5 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -4470,37 +4470,20 @@ def HLSLShader : InheritableAttr {
let Subjects = SubjectList<[HLSLEntry]>;
let LangOpts = [HLSL];
let Args = [
- EnumArgument<"Type", "ShaderType", /*is_string=*/true,
+ EnumArgument<"Type", "llvm::Triple::EnvironmentType", /*is_string=*/true,
["pixel", "vertex", "geometry", "hull", "domain", "compute",
"raygeneration", "intersection", "anyhit", "closesthit",
"miss", "callable", "mesh", "amplification"],
["Pixel", "Vertex", "Geometry", "Hull", "Domain", "Compute",
"RayGeneration", "Intersection", "AnyHit", "ClosestHit",
- "Miss", "Callable", "Mesh", "Amplification"]>
+ "Miss", "Callable", "Mesh", "Amplification"],
+ /*opt=*/0, /*fake=*/0, /*isExternalType=*/1>
];
let Documentation = [HLSLSV_ShaderTypeAttrDocs];
let AdditionalMembers =
[{
- static const unsigned ShaderTypeMaxValue = (unsigned)HLSLShaderAttr::Amplification;
-
- static llvm::Triple::EnvironmentType getTypeAsEnvironment(HLSLShaderAttr::ShaderType ShaderType) {
- switch (ShaderType) {
- case HLSLShaderAttr::Pixel: return llvm::Triple::Pixel;
- case HLSLShaderAttr::Vertex: return llvm::Triple::Vertex;
- case HLSLShaderAttr::Geometry: return llvm::Triple::Geometry;
- case HLSLShaderAttr::Hull: return llvm::Triple::Hull;
- case HLSLShaderAttr::Domain: return llvm::Triple::Domain;
- case HLSLShaderAttr::Compute: return llvm::Triple::Compute;
- case HLSLShaderAttr::RayGeneration: return llvm::Triple::RayGeneration;
- case HLSLShaderAttr::Intersection: return llvm::Triple::Intersection;
- case HLSLShaderAttr::AnyHit: return llvm::Triple::AnyHit;
- case HLSLShaderAttr::ClosestHit: return llvm::Triple::ClosestHit;
- case HLSLShaderAttr::Miss: return llvm::Triple::Miss;
- case HLSLShaderAttr::Callable: return llvm::Triple::Callable;
- case HLSLShaderAttr::Mesh: return llvm::Triple::Mesh;
- case HLSLShaderAttr::Amplification: return llvm::Triple::Amplification;
- }
- llvm_unreachable("unknown enumeration value");
+ static bool isValidShaderType(llvm::Triple::EnvironmentType ShaderType) {
+ return ShaderType >= llvm::Triple::Pixel && ShaderType <= llvm::Triple::Amplification;
}
}];
}
diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h
index e145f5e7f43f8..0e41a72e444ef 100644
--- a/clang/include/clang/Sema/SemaHLSL.h
+++ b/clang/include/clang/Sema/SemaHLSL.h
@@ -39,7 +39,7 @@ class SemaHLSL : public SemaBase {
const AttributeCommonInfo &AL, int X,
int Y, int Z);
HLSLShaderAttr *mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
- HLSLShaderAttr::ShaderType ShaderType);
+ llvm::Triple::EnvironmentType ShaderType);
HLSLParamModifierAttr *
mergeParamModifierAttr(Decl *D, const AttributeCommonInfo &AL,
HLSLParamModifierAttr::Spelling Spelling);
@@ -48,8 +48,8 @@ class SemaHLSL : public SemaBase {
void CheckSemanticAnnotation(FunctionDecl *EntryPoint, const Decl *Param,
const HLSLAnnotationAttr *AnnotationAttr);
void DiagnoseAttrStageMismatch(
- const Attr *A, HLSLShaderAttr::ShaderType Stage,
- std::initializer_list<HLSLShaderAttr::ShaderType> AllowedStages);
+ const Attr *A, llvm::Triple::EnvironmentType Stage,
+ std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages);
void DiagnoseAvailabilityViolations(TranslationUnitDecl *TU);
void handleNumThreadsAttr(Decl *D, const ParsedAttr &AL);
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.cpp b/clang/lib/CodeGen/CGHLSLRuntime.cpp
index 5e6a3dd4878f4..55ba21ae2ba69 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.cpp
+++ b/clang/lib/CodeGen/CGHLSLRuntime.cpp
@@ -313,7 +313,7 @@ void clang::CodeGen::CGHLSLRuntime::setHLSLEntryAttributes(
assert(ShaderAttr && "All entry functions must have a HLSLShaderAttr");
const StringRef ShaderAttrKindStr = "hlsl.shader";
Fn->addFnAttr(ShaderAttrKindStr,
- ShaderAttr->ConvertShaderTypeToStr(ShaderAttr->getType()));
+ llvm::Triple::getEnvironmentTypeName(ShaderAttr->getType()));
if (HLSLNumThreadsAttr *NumThreadsAttr = FD->getAttr<HLSLNumThreadsAttr>()) {
const StringRef NumThreadsKindStr = "hlsl.numthreads";
std::string NumThreadsStr =
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 0a2face7afe65..144cdcc0d98ef 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -146,7 +146,7 @@ HLSLNumThreadsAttr *SemaHLSL::mergeNumThreadsAttr(Decl *D,
HLSLShaderAttr *
SemaHLSL::mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
- HLSLShaderAttr::ShaderType ShaderType) {
+ llvm::Triple::EnvironmentType ShaderType) {
if (HLSLShaderAttr *NT = D->getAttr<HLSLShaderAttr>()) {
if (NT->getType() != ShaderType) {
Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
@@ -184,13 +184,12 @@ void SemaHLSL::ActOnTopLevelFunction(FunctionDecl *FD) {
if (FD->getName() != TargetInfo.getTargetOpts().HLSLEntry)
return;
- StringRef Env = TargetInfo.getTriple().getEnvironmentName();
- HLSLShaderAttr::ShaderType ShaderType;
- if (HLSLShaderAttr::ConvertStrToShaderType(Env, ShaderType)) {
+ llvm::Triple::EnvironmentType Env = TargetInfo.getTriple().getEnvironment();
+ if (HLSLShaderAttr::isValidShaderType(Env) && Env != llvm::Triple::Library) {
if (const auto *Shader = FD->getAttr<HLSLShaderAttr>()) {
// The entry point is already annotated - check that it matches the
// triple.
- if (Shader->getType() != ShaderType) {
+ if (Shader->getType() != Env) {
Diag(Shader->getLocation(), diag::err_hlsl_entry_shader_attr_mismatch)
<< Shader;
FD->setInvalidDecl();
@@ -198,11 +197,11 @@ void SemaHLSL::ActOnTopLevelFunction(FunctionDecl *FD) {
} else {
// Implicitly add the shader attribute if the entry function isn't
// explicitly annotated.
- FD->addAttr(HLSLShaderAttr::CreateImplicit(getASTContext(), ShaderType,
+ FD->addAttr(HLSLShaderAttr::CreateImplicit(getASTContext(), Env,
FD->getBeginLoc()));
}
} else {
- switch (TargetInfo.getTriple().getEnvironment()) {
+ switch (Env) {
case llvm::Triple::UnknownEnvironment:
case llvm::Triple::Library:
break;
@@ -215,38 +214,40 @@ void SemaHLSL::ActOnTopLevelFunction(FunctionDecl *FD) {
void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) {
const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
assert(ShaderAttr && "Entry point has no shader attribute");
- HLSLShaderAttr::ShaderType ST = ShaderAttr->getType();
+ llvm::Triple::EnvironmentType ST = ShaderAttr->getType();
switch (ST) {
- case HLSLShaderAttr::Pixel:
- case HLSLShaderAttr::Vertex:
- case HLSLShaderAttr::Geometry:
- case HLSLShaderAttr::Hull:
- case HLSLShaderAttr::Domain:
- case HLSLShaderAttr::RayGeneration:
- case HLSLShaderAttr::Intersection:
- case HLSLShaderAttr::AnyHit:
- case HLSLShaderAttr::ClosestHit:
- case HLSLShaderAttr::Miss:
- case HLSLShaderAttr::Callable:
+ case llvm::Triple::Pixel:
+ case llvm::Triple::Vertex:
+ case llvm::Triple::Geometry:
+ case llvm::Triple::Hull:
+ case llvm::Triple::Domain:
+ case llvm::Triple::RayGeneration:
+ case llvm::Triple::Intersection:
+ case llvm::Triple::AnyHit:
+ case llvm::Triple::ClosestHit:
+ case llvm::Triple::Miss:
+ case llvm::Triple::Callable:
if (const auto *NT = FD->getAttr<HLSLNumThreadsAttr>()) {
DiagnoseAttrStageMismatch(NT, ST,
- {HLSLShaderAttr::Compute,
- HLSLShaderAttr::Amplification,
- HLSLShaderAttr::Mesh});
+ {llvm::Triple::Compute,
+ llvm::Triple::Amplification,
+ llvm::Triple::Mesh});
FD->setInvalidDecl();
}
break;
- case HLSLShaderAttr::Compute:
- case HLSLShaderAttr::Amplification:
- case HLSLShaderAttr::Mesh:
+ case llvm::Triple::Compute:
+ case llvm::Triple::Amplification:
+ case llvm::Triple::Mesh:
if (!FD->hasAttr<HLSLNumThreadsAttr>()) {
Diag(FD->getLocation(), diag::err_hlsl_missing_numthreads)
- << HLSLShaderAttr::ConvertShaderTypeToStr(ST);
+ << llvm::Triple::getEnvironmentTypeName(ST);
FD->setInvalidDecl();
}
break;
+ default:
+ llvm_unreachable("Unhandled environment in triple");
}
for (ParmVarDecl *Param : FD->parameters()) {
@@ -268,14 +269,14 @@ void SemaHLSL::CheckSemanticAnnotation(
const HLSLAnnotationAttr *AnnotationAttr) {
auto *ShaderAttr = EntryPoint->getAttr<HLSLShaderAttr>();
assert(ShaderAttr && "Entry point has no shader attribute");
- HLSLShaderAttr::ShaderType ST = ShaderAttr->getType();
+ llvm::Triple::EnvironmentType ST = ShaderAttr->getType();
switch (AnnotationAttr->getKind()) {
case attr::HLSLSV_DispatchThreadID:
case attr::HLSLSV_GroupIndex:
- if (ST == HLSLShaderAttr::Compute)
+ if (ST == llvm::Triple::Compute)
return;
- DiagnoseAttrStageMismatch(AnnotationAttr, ST, {HLSLShaderAttr::Compute});
+ DiagnoseAttrStageMismatch(AnnotationAttr, ST, {llvm::Triple::Compute});
break;
default:
llvm_unreachable("Unknown HLSLAnnotationAttr");
@@ -283,16 +284,16 @@ void SemaHLSL::CheckSemanticAnnotation(
}
void SemaHLSL::DiagnoseAttrStageMismatch(
- const Attr *A, HLSLShaderAttr::ShaderType Stage,
- std::initializer_list<HLSLShaderAttr::ShaderType> AllowedStages) {
+ const Attr *A, llvm::Triple::EnvironmentType Stage,
+ std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages) {
SmallVector<StringRef, 8> StageStrings;
llvm::transform(AllowedStages, std::back_inserter(StageStrings),
- [](HLSLShaderAttr::ShaderType ST) {
+ [](llvm::Triple::EnvironmentType ST) {
return StringRef(
- HLSLShaderAttr::ConvertShaderTypeToStr(ST));
+ HLSLShaderAttr::ConvertEnvironmentTypeToStr(ST));
});
Diag(A->getLoc(), diag::err_hlsl_attr_unsupported_in_stage)
- << A << HLSLShaderAttr::ConvertShaderTypeToStr(Stage)
+ << A << llvm::Triple::getEnvironmentTypeName(Stage)
<< (AllowedStages.size() != 1) << join(StageStrings, ", ");
}
@@ -430,8 +431,8 @@ void SemaHLSL::handleShaderAttr(Decl *D, const ParsedAttr &AL) {
if (!SemaRef.checkStringLiteralArgumentAttr(AL, 0, Str, &ArgLoc))
return;
- HLSLShaderAttr::ShaderType ShaderType;
- if (!HLSLShaderAttr::ConvertStrToShaderType(Str, ShaderType)) {
+ llvm::Triple::EnvironmentType ShaderType;
+ if (!HLSLShaderAttr::ConvertStrToEnvironmentType(Str, ShaderType)) {
Diag(AL.getLoc(), diag::warn_attribute_type_not_supported)
<< AL << Str << ArgLoc;
return;
@@ -549,16 +550,22 @@ class DiagnoseHLSLAvailability
//
// Maps FunctionDecl to an unsigned number that represents the set of shader
// environments the function has been scanned for.
- // Since HLSLShaderAttr::ShaderType enum is generated from Attr.td and is
- // defined without any assigned values, it is guaranteed to be numbered
- // sequentially from 0 up and we can use it to 'index' individual bits
- // in the set.
+ // The llvm::Triple::EnvironmentType enum values for shader stages are
+ // guaranteed to be numbered from llvm::Triple::Pixel to
+ // llvm::Triple::Amplification (verified by static_asserts in Triple.cpp),
+ // so we can use them to index individual bits in the set, as long as we
+ // shift the values to start at 0 by subtracting llvm::Triple::Pixel first.
+ //
// The N'th bit in the set will be set if the function has been scanned
- // in shader environment whose ShaderType integer value equals N.
+ // in shader environment whose llvm::Triple::EnvironmentType integer value
+ // equals (llvm::Triple::Pixel + N).
+ //
// For example, if a function has been scanned in compute and pixel stage
- // environment, the value will be 0x21 (100001 binary) because
- // (int)HLSLShaderAttr::ShaderType::Pixel == 1 and
- // (int)HLSLShaderAttr::ShaderType::Compute == 5.
+ // environment, the value will be 0x21 (100001 binary) because:
+ //
+ // (int)(llvm::Triple::Pixel - llvm::Triple::Pixel) == 0
+ // (int)(llvm::Triple::Compute - llvm::Triple::Pixel) == 5
+ //
// A FunctionDecl is mapped to 0 (or not included in the map) if it has not
// been scanned in any environment.
llvm::DenseMap<const FunctionDecl *, unsigned> ScannedDecls;
@@ -574,12 +581,16 @@ class DiagnoseHLSLAvailability
bool ReportOnlyShaderStageIssues;
// Helper methods for dealing with current stage context / environment
- void SetShaderStageContext(HLSLShaderAttr::ShaderType ShaderType) {
+ void SetShaderStageContext(llvm::Triple::EnvironmentType ShaderType) {
static_assert(sizeof(unsigned) >= 4);
- assert((unsigned)ShaderType < 31); // 31 is reserved for "unknown"
-
- CurrentShaderEnvironment = HLSLShaderAttr::getTypeAsEnvironment(ShaderType);
- CurrentShaderStageBit = (1 << ShaderType);
+ assert(HLSLShaderAttr::isValidShaderType(ShaderType));
+ assert((unsigned)(ShaderType - llvm::Triple::Pixel) < 31 &&
+ "ShaderType is too big for this bitmap"); // 31 is reserved for
+ // "unknown"
+
+ unsigned bitmapIndex = ShaderType - llvm::Triple::Pixel;
+ CurrentShaderEnvironment = ShaderType;
+ CurrentShaderStageBit = (1 << bitmapIndex);
}
void SetUnknownShaderStageContext() {
More information about the cfe-commits
mailing list