[llvm] 6312d27 - [DirectX] Emit `hlsl.wavesize` function attribute as entry property metadata (#165624)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Nov 5 09:18:54 PST 2025
Author: Finn Plummer
Date: 2025-11-05T09:18:49-08:00
New Revision: 6312d2751144bd53af7ef56798cbe60aa8b2fb56
URL: https://github.com/llvm/llvm-project/commit/6312d2751144bd53af7ef56798cbe60aa8b2fb56
DIFF: https://github.com/llvm/llvm-project/commit/6312d2751144bd53af7ef56798cbe60aa8b2fb56.diff
LOG: [DirectX] Emit `hlsl.wavesize` function attribute as entry property metadata (#165624)
This PR adds support for emitting the `hlsl.wavesize` function attribute
as entry property metadata for a compute shader.
It follows the implementation of `hlsl.numthreads`:
- Collect the wave size/range information from the function attribute in
`DXILMetadataAnalysis`
- Introduce the `WaveRange` property tag
- Emit `WaveSize` or `WaveRange` metadata (depending on the shader model)
in `DXILTranslateMetadata`
- Add tests for valid and invalid scenarios
- Update the base `PSVInfo` to reflect the min/max wave lane counts
Resolves #70118
Added:
llvm/test/CodeGen/DirectX/wavesize-md-errs.ll
llvm/test/CodeGen/DirectX/wavesize-md-valid.ll
Modified:
llvm/include/llvm/Analysis/DXILMetadataAnalysis.h
llvm/lib/Analysis/DXILMetadataAnalysis.cpp
llvm/lib/Target/DirectX/DXContainerGlobals.cpp
llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/Analysis/DXILMetadataAnalysis.h b/llvm/include/llvm/Analysis/DXILMetadataAnalysis.h
index cb535ac14f1c6..a1b030c157eae 100644
--- a/llvm/include/llvm/Analysis/DXILMetadataAnalysis.h
+++ b/llvm/include/llvm/Analysis/DXILMetadataAnalysis.h
@@ -27,6 +27,9 @@ struct EntryProperties {
unsigned NumThreadsX{0}; // X component
unsigned NumThreadsY{0}; // Y component
unsigned NumThreadsZ{0}; // Z component
+ unsigned WaveSizeMin{0}; // Minimum component
+ unsigned WaveSizeMax{0}; // Maximum component
+ unsigned WaveSizePref{0}; // Preferred component
EntryProperties(const Function *Fn = nullptr) : Entry(Fn) {};
};
diff --git a/llvm/lib/Analysis/DXILMetadataAnalysis.cpp b/llvm/lib/Analysis/DXILMetadataAnalysis.cpp
index 23f1aa82ae8a3..bd77cba385667 100644
--- a/llvm/lib/Analysis/DXILMetadataAnalysis.cpp
+++ b/llvm/lib/Analysis/DXILMetadataAnalysis.cpp
@@ -66,6 +66,22 @@ static ModuleMetadataInfo collectMetadataInfo(Module &M) {
Success = llvm::to_integer(NumThreadsVec[2], EFP.NumThreadsZ, 10);
assert(Success && "Failed to parse Z component of numthreads");
}
+ // Get wavesize attribute value, if one exists
+ StringRef WaveSizeStr =
+ F.getFnAttribute("hlsl.wavesize").getValueAsString();
+ if (!WaveSizeStr.empty()) {
+ SmallVector<StringRef> WaveSizeVec;
+ WaveSizeStr.split(WaveSizeVec, ',');
+ assert(WaveSizeVec.size() == 3 && "Invalid wavesize specified");
+ // Read in the three component values of wavesize
+ [[maybe_unused]] bool Success =
+ llvm::to_integer(WaveSizeVec[0], EFP.WaveSizeMin, 10);
+ assert(Success && "Failed to parse Min component of wavesize");
+ Success = llvm::to_integer(WaveSizeVec[1], EFP.WaveSizeMax, 10);
+ assert(Success && "Failed to parse Max component of wavesize");
+ Success = llvm::to_integer(WaveSizeVec[2], EFP.WaveSizePref, 10);
+ assert(Success && "Failed to parse Preferred component of wavesize");
+ }
MMDAI.EntryPropertyVec.push_back(EFP);
}
return MMDAI;
diff --git a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
index eb4c8846441a2..677203d1c016b 100644
--- a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
+++ b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
@@ -285,6 +285,13 @@ void DXContainerGlobals::addPipelineStateValidationInfo(
PSV.BaseData.NumThreadsX = MMI.EntryPropertyVec[0].NumThreadsX;
PSV.BaseData.NumThreadsY = MMI.EntryPropertyVec[0].NumThreadsY;
PSV.BaseData.NumThreadsZ = MMI.EntryPropertyVec[0].NumThreadsZ;
+ if (MMI.EntryPropertyVec[0].WaveSizeMin) {
+ PSV.BaseData.MinimumWaveLaneCount = MMI.EntryPropertyVec[0].WaveSizeMin;
+ PSV.BaseData.MaximumWaveLaneCount =
+ MMI.EntryPropertyVec[0].WaveSizeMax
+ ? MMI.EntryPropertyVec[0].WaveSizeMax
+ : MMI.EntryPropertyVec[0].WaveSizeMin;
+ }
break;
default:
break;
diff --git a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
index cf8b833b3e42e..e1a472fe57642 100644
--- a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
+++ b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
@@ -82,6 +82,7 @@ enum class EntryPropsTag {
ASStateTag,
WaveSize,
EntryRootSig,
+ WaveRange = 23,
};
} // namespace
@@ -177,14 +178,15 @@ getTagValueAsMetadata(EntryPropsTag Tag, uint64_t Value, LLVMContext &Ctx) {
case EntryPropsTag::ASStateTag:
case EntryPropsTag::WaveSize:
case EntryPropsTag::EntryRootSig:
+ case EntryPropsTag::WaveRange:
llvm_unreachable("NYI: Unhandled entry property tag");
}
return MDVals;
}
-static MDTuple *
-getEntryPropAsMetadata(const EntryProperties &EP, uint64_t EntryShaderFlags,
- const Triple::EnvironmentType ShaderProfile) {
+static MDTuple *getEntryPropAsMetadata(Module &M, const EntryProperties &EP,
+ uint64_t EntryShaderFlags,
+ const ModuleMetadataInfo &MMDI) {
SmallVector<Metadata *> MDVals;
LLVMContext &Ctx = EP.Entry->getContext();
if (EntryShaderFlags != 0)
@@ -195,12 +197,13 @@ getEntryPropAsMetadata(const EntryProperties &EP, uint64_t EntryShaderFlags,
// FIXME: support more props.
// See https://github.com/llvm/llvm-project/issues/57948.
// Add shader kind for lib entries.
- if (ShaderProfile == Triple::EnvironmentType::Library &&
+ if (MMDI.ShaderProfile == Triple::EnvironmentType::Library &&
EP.ShaderStage != Triple::EnvironmentType::Library)
MDVals.append(getTagValueAsMetadata(EntryPropsTag::ShaderKind,
getShaderStage(EP.ShaderStage), Ctx));
if (EP.ShaderStage == Triple::EnvironmentType::Compute) {
+ // Handle mandatory "hlsl.numthreads"
MDVals.emplace_back(ConstantAsMetadata::get(ConstantInt::get(
Type::getInt32Ty(Ctx), static_cast<int>(EntryPropsTag::NumThreads))));
Metadata *NumThreadVals[] = {ConstantAsMetadata::get(ConstantInt::get(
@@ -210,8 +213,48 @@ getEntryPropAsMetadata(const EntryProperties &EP, uint64_t EntryShaderFlags,
ConstantAsMetadata::get(ConstantInt::get(
Type::getInt32Ty(Ctx), EP.NumThreadsZ))};
MDVals.emplace_back(MDNode::get(Ctx, NumThreadVals));
+
+ // Handle optional "hlsl.wavesize". The property is emitted only when the
+ // minimum wave size is non-zero.
+ if (EP.WaveSizeMin != 0) {
+ bool IsWaveRange = VersionTuple(6, 8) <= MMDI.ShaderModelVersion;
+ bool IsWaveSize =
+ !IsWaveRange && VersionTuple(6, 6) <= MMDI.ShaderModelVersion;
+
+ if (!IsWaveRange && !IsWaveSize) {
+ reportError(M, "Shader model 6.6 or greater is required to specify "
+ "the \"hlsl.wavesize\" function attribute");
+ return nullptr;
+ }
+
+ // A range is being specified if EP.WaveSizeMax != 0
+ if (EP.WaveSizeMax && !IsWaveRange) {
+ reportError(
+ M, "Shader model 6.8 or greater is required to specify "
+ "wave size range values of the \"hlsl.wavesize\" function "
+ "attribute");
+ return nullptr;
+ }
+
+ EntryPropsTag Tag =
+ IsWaveSize ? EntryPropsTag::WaveSize : EntryPropsTag::WaveRange;
+ MDVals.emplace_back(ConstantAsMetadata::get(
+ ConstantInt::get(Type::getInt32Ty(Ctx), static_cast<int>(Tag))));
+
+ SmallVector<Metadata *> WaveSizeVals = {ConstantAsMetadata::get(
+ ConstantInt::get(Type::getInt32Ty(Ctx), EP.WaveSizeMin))};
+ if (IsWaveRange) {
+ WaveSizeVals.push_back(ConstantAsMetadata::get(
+ ConstantInt::get(Type::getInt32Ty(Ctx), EP.WaveSizeMax)));
+ WaveSizeVals.push_back(ConstantAsMetadata::get(
+ ConstantInt::get(Type::getInt32Ty(Ctx), EP.WaveSizePref)));
+ }
+
+ MDVals.emplace_back(MDNode::get(Ctx, WaveSizeVals));
+ }
}
}
+
if (MDVals.empty())
return nullptr;
return MDNode::get(Ctx, MDVals);
@@ -236,12 +279,11 @@ static MDTuple *constructEntryMetadata(const Function *EntryFn,
return MDNode::get(Ctx, MDVals);
}
-static MDTuple *emitEntryMD(const EntryProperties &EP, MDTuple *Signatures,
- MDNode *MDResources,
+static MDTuple *emitEntryMD(Module &M, const EntryProperties &EP,
+ MDTuple *Signatures, MDNode *MDResources,
const uint64_t EntryShaderFlags,
- const Triple::EnvironmentType ShaderProfile) {
- MDTuple *Properties =
- getEntryPropAsMetadata(EP, EntryShaderFlags, ShaderProfile);
+ const ModuleMetadataInfo &MMDI) {
+ MDTuple *Properties = getEntryPropAsMetadata(M, EP, EntryShaderFlags, MMDI);
return constructEntryMetadata(EP.Entry, Signatures, MDResources, Properties,
EP.Entry->getContext());
}
@@ -523,10 +565,8 @@ static void translateGlobalMetadata(Module &M, DXILResourceMap &DRM,
Twine(Triple::getEnvironmentTypeName(MMDI.ShaderProfile) +
"'"));
}
-
- EntryFnMDNodes.emplace_back(emitEntryMD(EntryProp, Signatures, ResourceMD,
- EntryShaderFlags,
- MMDI.ShaderProfile));
+ EntryFnMDNodes.emplace_back(emitEntryMD(
+ M, EntryProp, Signatures, ResourceMD, EntryShaderFlags, MMDI));
}
NamedMDNode *EntryPointsNamedMD =
diff --git a/llvm/test/CodeGen/DirectX/wavesize-md-errs.ll b/llvm/test/CodeGen/DirectX/wavesize-md-errs.ll
new file mode 100644
index 0000000000000..9016c5d7e8d44
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/wavesize-md-errs.ll
@@ -0,0 +1,31 @@
+; RUN: split-file %s %t
+; RUN: not opt -S --dxil-translate-metadata %t/low-sm.ll 2>&1 | FileCheck %t/low-sm.ll
+; RUN: not opt -S --dxil-translate-metadata %t/low-sm-for-range.ll 2>&1 | FileCheck %t/low-sm-for-range.ll
+
+; Test that wavesize metadata is only allowed on applicable shader model versions
+
+;--- low-sm.ll
+
+; CHECK: Shader model 6.6 or greater is required to specify the "hlsl.wavesize" function attribute
+
+target triple = "dxil-unknown-shadermodel6.5-compute"
+
+define void @main() #0 {
+entry:
+ ret void
+}
+
+attributes #0 = { "hlsl.wavesize"="16,0,0" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
+
+;--- low-sm-for-range.ll
+
+; CHECK: Shader model 6.8 or greater is required to specify wave size range values of the "hlsl.wavesize" function attribute
+
+target triple = "dxil-unknown-shadermodel6.7-compute"
+
+define void @main() #0 {
+entry:
+ ret void
+}
+
+attributes #0 = { "hlsl.wavesize"="16,32,0" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
diff --git a/llvm/test/CodeGen/DirectX/wavesize-md-valid.ll b/llvm/test/CodeGen/DirectX/wavesize-md-valid.ll
new file mode 100644
index 0000000000000..3ad6c1d034252
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/wavesize-md-valid.ll
@@ -0,0 +1,96 @@
+; RUN: split-file %s %t
+; RUN: opt -S --dxil-translate-metadata %t/only.ll | FileCheck %t/only.ll
+; RUN: opt -S --dxil-translate-metadata %t/min.ll | FileCheck %t/min.ll
+; RUN: opt -S --dxil-translate-metadata %t/max.ll | FileCheck %t/max.ll
+; RUN: opt -S --dxil-translate-metadata %t/pref.ll | FileCheck %t/pref.ll
+
+; RUN: llc --filetype=obj %t/only.ll -o - | obj2yaml | FileCheck %t/only.ll --check-prefix=OBJ
+; RUN: llc --filetype=obj %t/min.ll -o - | obj2yaml | FileCheck %t/min.ll --check-prefix=OBJ
+; RUN: llc --filetype=obj %t/max.ll -o - | obj2yaml | FileCheck %t/max.ll --check-prefix=OBJ
+; RUN: llc --filetype=obj %t/pref.ll -o - | obj2yaml | FileCheck %t/pref.ll --check-prefix=OBJ
+
+; Test that wave size/range metadata is correctly generated with the correct tag
+
+;--- only.ll
+
+; CHECK: !dx.entryPoints = !{![[#ENTRY:]]}
+; CHECK: ![[#ENTRY]] = !{ptr @main, !"main", null, null, ![[#PROPS:]]}
+; CHECK: ![[#PROPS]] = !{{{.*}}i32 11, ![[#WAVE_SIZE:]]{{.*}}}
+; CHECK: ![[#WAVE_SIZE]] = !{i32 16}
+
+; OBJ: - Name: PSV0
+; OBJ: PSVInfo:
+; OBJ: MinimumWaveLaneCount: 16
+; OBJ: MaximumWaveLaneCount: 16
+
+target triple = "dxil-unknown-shadermodel6.6-compute"
+
+define void @main() #0 {
+entry:
+ ret void
+}
+
+attributes #0 = { "hlsl.wavesize"="16,0,0" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
+
+;--- min.ll
+
+; CHECK: !dx.entryPoints = !{![[#ENTRY:]]}
+; CHECK: ![[#ENTRY]] = !{ptr @main, !"main", null, null, ![[#PROPS:]]}
+; CHECK: ![[#PROPS]] = !{{{.*}}i32 23, ![[#WAVE_SIZE:]]{{.*}}}
+; CHECK: ![[#WAVE_SIZE]] = !{i32 16, i32 0, i32 0}
+
+; OBJ: - Name: PSV0
+; OBJ: PSVInfo:
+; OBJ: MinimumWaveLaneCount: 16
+; OBJ: MaximumWaveLaneCount: 16
+
+target triple = "dxil-unknown-shadermodel6.8-compute"
+
+define void @main() #0 {
+entry:
+ ret void
+}
+
+attributes #0 = { "hlsl.wavesize"="16,0,0" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
+
+;--- max.ll
+
+; CHECK: !dx.entryPoints = !{![[#ENTRY:]]}
+; CHECK: ![[#ENTRY]] = !{ptr @main, !"main", null, null, ![[#PROPS:]]}
+; CHECK: ![[#PROPS]] = !{{{.*}}i32 23, ![[#WAVE_SIZE:]]{{.*}}}
+; CHECK: ![[#WAVE_SIZE]] = !{i32 16, i32 32, i32 0}
+
+; OBJ: - Name: PSV0
+; OBJ: PSVInfo:
+; OBJ: MinimumWaveLaneCount: 16
+; OBJ: MaximumWaveLaneCount: 32
+
+target triple = "dxil-unknown-shadermodel6.8-compute"
+
+define void @main() #0 {
+entry:
+ ret void
+}
+
+attributes #0 = { "hlsl.wavesize"="16,32,0" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
+
+;--- pref.ll
+
+; CHECK: !dx.entryPoints = !{![[#ENTRY:]]}
+; CHECK: ![[#ENTRY]] = !{ptr @main, !"main", null, null, ![[#PROPS:]]}
+; CHECK: ![[#PROPS]] = !{{{.*}}i32 23, ![[#WAVE_SIZE:]]{{.*}}}
+; CHECK: ![[#WAVE_SIZE]] = !{i32 16, i32 64, i32 32}
+
+; OBJ: - Name: PSV0
+; OBJ: PSVInfo:
+; OBJ: MinimumWaveLaneCount: 16
+; OBJ: MaximumWaveLaneCount: 64
+
+target triple = "dxil-unknown-shadermodel6.8-compute"
+
+define void @main() #0 {
+entry:
+ ret void
+}
+
+attributes #0 = { "hlsl.wavesize"="16,64,32" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
More information about the llvm-commits
mailing list