[llvm] [DirectX] Emit `WaveSize` function attribute metadata (PR #165624)
Finn Plummer via llvm-commits
llvm-commits at lists.llvm.org
Wed Oct 29 13:46:44 PDT 2025
https://github.com/inbelic created https://github.com/llvm/llvm-project/pull/165624
Resolves #70118
>From 793019a5052f27d770f9874f3cea311c9968339d Mon Sep 17 00:00:00 2001
From: Finn Plummer <mail at inbelic.dev>
Date: Wed, 29 Oct 2025 13:45:14 -0700
Subject: [PATCH] [DirectX] Emit `WaveSize` function attribute metadata
---
.../llvm/Analysis/DXILMetadataAnalysis.h | 3 +
llvm/lib/Analysis/DXILMetadataAnalysis.cpp | 16 +++++
.../Target/DirectX/DXILTranslateMetadata.cpp | 67 +++++++++++++----
llvm/test/CodeGen/DirectX/wavesize-md-errs.ll | 31 ++++++++
.../test/CodeGen/DirectX/wavesize-md-valid.ll | 71 +++++++++++++++++++
5 files changed, 174 insertions(+), 14 deletions(-)
create mode 100644 llvm/test/CodeGen/DirectX/wavesize-md-errs.ll
create mode 100644 llvm/test/CodeGen/DirectX/wavesize-md-valid.ll
diff --git a/llvm/include/llvm/Analysis/DXILMetadataAnalysis.h b/llvm/include/llvm/Analysis/DXILMetadataAnalysis.h
index cb535ac14f1c6..a1b030c157eae 100644
--- a/llvm/include/llvm/Analysis/DXILMetadataAnalysis.h
+++ b/llvm/include/llvm/Analysis/DXILMetadataAnalysis.h
@@ -27,6 +27,9 @@ struct EntryProperties {
unsigned NumThreadsX{0}; // X component
unsigned NumThreadsY{0}; // Y component
unsigned NumThreadsZ{0}; // Z component
+ unsigned WaveSizeMin{0}; // Minimum (pre-SM6.8: only) wave size; 0 if unset
+ unsigned WaveSizeMax{0}; // Maximum wave size of a range; 0 if unset
+ unsigned WaveSizePref{0}; // Preferred wave size of a range; 0 if unset
EntryProperties(const Function *Fn = nullptr) : Entry(Fn) {};
};
diff --git a/llvm/lib/Analysis/DXILMetadataAnalysis.cpp b/llvm/lib/Analysis/DXILMetadataAnalysis.cpp
index 23f1aa82ae8a3..bd77cba385667 100644
--- a/llvm/lib/Analysis/DXILMetadataAnalysis.cpp
+++ b/llvm/lib/Analysis/DXILMetadataAnalysis.cpp
@@ -66,6 +66,22 @@ static ModuleMetadataInfo collectMetadataInfo(Module &M) {
Success = llvm::to_integer(NumThreadsVec[2], EFP.NumThreadsZ, 10);
assert(Success && "Failed to parse Z component of numthreads");
}
+ // Get the "hlsl.wavesize" attribute value ("min,max,pref"), if one exists
+ StringRef WaveSizeStr =
+ F.getFnAttribute("hlsl.wavesize").getValueAsString();
+ if (!WaveSizeStr.empty()) {
+ SmallVector<StringRef> WaveSizeVec;
+ WaveSizeStr.split(WaveSizeVec, ',');
+ assert(WaveSizeVec.size() == 3 && "Invalid wavesize specified");
+ // Read in the three component values of wavesize (min, max, preferred)
+ [[maybe_unused]] bool Success =
+ llvm::to_integer(WaveSizeVec[0], EFP.WaveSizeMin, 10);
+ assert(Success && "Failed to parse Min component of wavesize");
+ Success = llvm::to_integer(WaveSizeVec[1], EFP.WaveSizeMax, 10);
+ assert(Success && "Failed to parse Max component of wavesize");
+ Success = llvm::to_integer(WaveSizeVec[2], EFP.WaveSizePref, 10);
+ assert(Success && "Failed to parse Preferred component of wavesize");
+ }
MMDAI.EntryPropertyVec.push_back(EFP);
}
return MMDAI;
diff --git a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
index cf8b833b3e42e..682847a94c6fb 100644
--- a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
+++ b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
@@ -82,6 +82,7 @@ enum class EntryPropsTag {
ASStateTag,
WaveSize,
EntryRootSig,
+ WaveRange = 23, // Used instead of WaveSize from SM 6.8 onwards
};
} // namespace
@@ -177,30 +178,32 @@ getTagValueAsMetadata(EntryPropsTag Tag, uint64_t Value, LLVMContext &Ctx) {
case EntryPropsTag::ASStateTag:
case EntryPropsTag::WaveSize:
case EntryPropsTag::EntryRootSig:
+ case EntryPropsTag::WaveRange:
llvm_unreachable("NYI: Unhandled entry property tag");
}
return MDVals;
}
-static MDTuple *
-getEntryPropAsMetadata(const EntryProperties &EP, uint64_t EntryShaderFlags,
- const Triple::EnvironmentType ShaderProfile) {
+static MDTuple *getEntryPropAsMetadata(Module &M, const EntryProperties &EP,
+ uint64_t EntryShaderFlags,
+ const ModuleMetadataInfo &MMDI) {
SmallVector<Metadata *> MDVals;
LLVMContext &Ctx = EP.Entry->getContext();
if (EntryShaderFlags != 0)
MDVals.append(getTagValueAsMetadata(EntryPropsTag::ShaderFlags,
EntryShaderFlags, Ctx));
if (EP.Entry != nullptr) {
// FIXME: support more props.
// See https://github.com/llvm/llvm-project/issues/57948.
// Add shader kind for lib entries.
- if (ShaderProfile == Triple::EnvironmentType::Library &&
+ if (MMDI.ShaderProfile == Triple::EnvironmentType::Library &&
EP.ShaderStage != Triple::EnvironmentType::Library)
MDVals.append(getTagValueAsMetadata(EntryPropsTag::ShaderKind,
getShaderStage(EP.ShaderStage), Ctx));
if (EP.ShaderStage == Triple::EnvironmentType::Compute) {
+ // Handle mandatory "hlsl.numthreads"
MDVals.emplace_back(ConstantAsMetadata::get(ConstantInt::get(
Type::getInt32Ty(Ctx), static_cast<int>(EntryPropsTag::NumThreads))));
Metadata *NumThreadVals[] = {ConstantAsMetadata::get(ConstantInt::get(
@@ -210,8 +213,47 @@ getEntryPropAsMetadata(const EntryProperties &EP, uint64_t EntryShaderFlags,
ConstantAsMetadata::get(ConstantInt::get(
Type::getInt32Ty(Ctx), EP.NumThreadsZ))};
MDVals.emplace_back(MDNode::get(Ctx, NumThreadVals));
+
+ // Handle optional "hlsl.wavesize"; a zero minimum means it is absent.
+ // Pre-SM6.8 emits only the minimum; SM6.8+ emits all three fields.
+ if (EP.WaveSizeMin != 0) {
+ bool IsWaveRange = VersionTuple(6, 8) <= MMDI.ShaderModelVersion;
+ bool IsWaveSize =
+ !IsWaveRange && VersionTuple(6, 6) <= MMDI.ShaderModelVersion;
+
+ if (!IsWaveRange && !IsWaveSize) {
+ reportError(M, "Shader model 6.6 or greater is required to specify "
+ "the \"hlsl.wavesize\" function attribute");
+ return nullptr;
+ }
+
+ if (EP.WaveSizeMax && !IsWaveRange) {
+ reportError(
+ M, "Shader model 6.8 or greater is required to specify "
+ "wave size range values of the \"hlsl.wavesize\" function "
+ "attribute");
+ return nullptr;
+ }
+
+ EntryPropsTag Tag =
+ IsWaveSize ? EntryPropsTag::WaveSize : EntryPropsTag::WaveRange;
+ MDVals.emplace_back(ConstantAsMetadata::get(
+ ConstantInt::get(Type::getInt32Ty(Ctx), static_cast<int>(Tag))));
+
+ SmallVector<Metadata *> WaveSizeVals = {ConstantAsMetadata::get(
+ ConstantInt::get(Type::getInt32Ty(Ctx), EP.WaveSizeMin))};
+ if (IsWaveRange) {
+ WaveSizeVals.push_back(ConstantAsMetadata::get(
+ ConstantInt::get(Type::getInt32Ty(Ctx), EP.WaveSizeMax)));
+ WaveSizeVals.push_back(ConstantAsMetadata::get(
+ ConstantInt::get(Type::getInt32Ty(Ctx), EP.WaveSizePref)));
+ }
+
+ MDVals.emplace_back(MDNode::get(Ctx, WaveSizeVals));
+ }
}
}
+
if (MDVals.empty())
return nullptr;
return MDNode::get(Ctx, MDVals);
@@ -236,12 +278,11 @@ static MDTuple *constructEntryMetadata(const Function *EntryFn,
return MDNode::get(Ctx, MDVals);
}
-static MDTuple *emitEntryMD(const EntryProperties &EP, MDTuple *Signatures,
- MDNode *MDResources,
+static MDTuple *emitEntryMD(Module &M, const EntryProperties &EP,
+ MDTuple *Signatures, MDNode *MDResources,
const uint64_t EntryShaderFlags,
- const Triple::EnvironmentType ShaderProfile) {
- MDTuple *Properties =
- getEntryPropAsMetadata(EP, EntryShaderFlags, ShaderProfile);
+ const ModuleMetadataInfo &MMDI) {
+ MDTuple *Properties = getEntryPropAsMetadata(M, EP, EntryShaderFlags, MMDI);
return constructEntryMetadata(EP.Entry, Signatures, MDResources, Properties,
EP.Entry->getContext());
}
@@ -523,10 +564,8 @@ static void translateGlobalMetadata(Module &M, DXILResourceMap &DRM,
Twine(Triple::getEnvironmentTypeName(MMDI.ShaderProfile) +
"'"));
}
-
- EntryFnMDNodes.emplace_back(emitEntryMD(EntryProp, Signatures, ResourceMD,
- EntryShaderFlags,
- MMDI.ShaderProfile));
+ EntryFnMDNodes.emplace_back(emitEntryMD(
+ M, EntryProp, Signatures, ResourceMD, EntryShaderFlags, MMDI));
}
NamedMDNode *EntryPointsNamedMD =
diff --git a/llvm/test/CodeGen/DirectX/wavesize-md-errs.ll b/llvm/test/CodeGen/DirectX/wavesize-md-errs.ll
new file mode 100644
index 0000000000000..9016c5d7e8d44
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/wavesize-md-errs.ll
@@ -0,0 +1,31 @@
+; RUN: split-file %s %t
+; RUN: not opt -S --dxil-translate-metadata %t/low-sm.ll 2>&1 | FileCheck %t/low-sm.ll
+; RUN: not opt -S --dxil-translate-metadata %t/low-sm-for-range.ll 2>&1 | FileCheck %t/low-sm-for-range.ll
+
+; Test that a wave size requires SM 6.6+ and a wave size range requires SM 6.8+
+
+;--- low-sm.ll
+
+; CHECK: Shader model 6.6 or greater is required to specify the "hlsl.wavesize" function attribute
+
+target triple = "dxil-unknown-shadermodel6.5-compute"
+
+define void @main() #0 {
+entry:
+ ret void
+}
+
+attributes #0 = { "hlsl.wavesize"="16,0,0" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
+
+;--- low-sm-for-range.ll
+
+; CHECK: Shader model 6.8 or greater is required to specify wave size range values of the "hlsl.wavesize" function attribute
+
+target triple = "dxil-unknown-shadermodel6.7-compute"
+
+define void @main() #0 {
+entry:
+ ret void
+}
+
+attributes #0 = { "hlsl.wavesize"="16,32,0" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
diff --git a/llvm/test/CodeGen/DirectX/wavesize-md-valid.ll b/llvm/test/CodeGen/DirectX/wavesize-md-valid.ll
new file mode 100644
index 0000000000000..63e8a59eb2648
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/wavesize-md-valid.ll
@@ -0,0 +1,71 @@
+; RUN: split-file %s %t
+; RUN: opt -S --dxil-translate-metadata %t/only.ll | FileCheck %t/only.ll
+; RUN: opt -S --dxil-translate-metadata %t/min.ll | FileCheck %t/min.ll
+; RUN: opt -S --dxil-translate-metadata %t/max.ll | FileCheck %t/max.ll
+; RUN: opt -S --dxil-translate-metadata %t/pref.ll | FileCheck %t/pref.ll
+
+; Test that SM 6.6-6.7 emits the WaveSize tag (11) and SM 6.8+ the WaveRange tag (23)
+
+;--- only.ll
+
+; CHECK: !dx.entryPoints = !{![[#ENTRY:]]}
+; CHECK: ![[#ENTRY]] = !{ptr @main, !"main", null, null, ![[#PROPS:]]}
+; CHECK: ![[#PROPS]] = !{{{.*}}i32 11, ![[#WAVE_SIZE:]]{{.*}}}
+; CHECK: ![[#WAVE_SIZE]] = !{i32 16}
+
+target triple = "dxil-unknown-shadermodel6.6-compute"
+
+define void @main() #0 {
+entry:
+ ret void
+}
+
+attributes #0 = { "hlsl.wavesize"="16,0,0" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
+
+;--- min.ll
+
+; CHECK: !dx.entryPoints = !{![[#ENTRY:]]}
+; CHECK: ![[#ENTRY]] = !{ptr @main, !"main", null, null, ![[#PROPS:]]}
+; CHECK: ![[#PROPS]] = !{{{.*}}i32 23, ![[#WAVE_SIZE:]]{{.*}}}
+; CHECK: ![[#WAVE_SIZE]] = !{i32 16, i32 0, i32 0}
+
+target triple = "dxil-unknown-shadermodel6.8-compute"
+
+define void @main() #0 {
+entry:
+ ret void
+}
+
+attributes #0 = { "hlsl.wavesize"="16,0,0" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
+
+;--- max.ll
+
+; CHECK: !dx.entryPoints = !{![[#ENTRY:]]}
+; CHECK: ![[#ENTRY]] = !{ptr @main, !"main", null, null, ![[#PROPS:]]}
+; CHECK: ![[#PROPS]] = !{{{.*}}i32 23, ![[#WAVE_SIZE:]]{{.*}}}
+; CHECK: ![[#WAVE_SIZE]] = !{i32 16, i32 32, i32 0}
+
+target triple = "dxil-unknown-shadermodel6.8-compute"
+
+define void @main() #0 {
+entry:
+ ret void
+}
+
+attributes #0 = { "hlsl.wavesize"="16,32,0" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
+
+;--- pref.ll
+
+; CHECK: !dx.entryPoints = !{![[#ENTRY:]]}
+; CHECK: ![[#ENTRY]] = !{ptr @main, !"main", null, null, ![[#PROPS:]]}
+; CHECK: ![[#PROPS]] = !{{{.*}}i32 23, ![[#WAVE_SIZE:]]{{.*}}}
+; CHECK: ![[#WAVE_SIZE]] = !{i32 16, i32 64, i32 32}
+
+target triple = "dxil-unknown-shadermodel6.8-compute"
+
+define void @main() #0 {
+entry:
+ ret void
+}
+
+attributes #0 = { "hlsl.wavesize"="16,64,32" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
More information about the llvm-commits
mailing list