[llvm] [DirectX] Emit `hlsl.wavesize` function attribute as entry property metadata (PR #165624)

Finn Plummer via llvm-commits llvm-commits at lists.llvm.org
Wed Oct 29 14:02:56 PDT 2025


https://github.com/inbelic updated https://github.com/llvm/llvm-project/pull/165624

>From 793019a5052f27d770f9874f3cea311c9968339d Mon Sep 17 00:00:00 2001
From: Finn Plummer <mail at inbelic.dev>
Date: Wed, 29 Oct 2025 13:45:14 -0700
Subject: [PATCH 1/2] [DirectX] Emit `WaveSize` function attribute metadata

---
 .../llvm/Analysis/DXILMetadataAnalysis.h      |  3 +
 llvm/lib/Analysis/DXILMetadataAnalysis.cpp    | 16 +++++
 .../Target/DirectX/DXILTranslateMetadata.cpp  | 67 +++++++++++++----
 llvm/test/CodeGen/DirectX/wavesize-md-errs.ll | 31 ++++++++
 .../test/CodeGen/DirectX/wavesize-md-valid.ll | 71 +++++++++++++++++++
 5 files changed, 174 insertions(+), 14 deletions(-)
 create mode 100644 llvm/test/CodeGen/DirectX/wavesize-md-errs.ll
 create mode 100644 llvm/test/CodeGen/DirectX/wavesize-md-valid.ll

diff --git a/llvm/include/llvm/Analysis/DXILMetadataAnalysis.h b/llvm/include/llvm/Analysis/DXILMetadataAnalysis.h
index cb535ac14f1c6..a1b030c157eae 100644
--- a/llvm/include/llvm/Analysis/DXILMetadataAnalysis.h
+++ b/llvm/include/llvm/Analysis/DXILMetadataAnalysis.h
@@ -27,6 +27,9 @@ struct EntryProperties {
   unsigned NumThreadsX{0}; // X component
   unsigned NumThreadsY{0}; // Y component
   unsigned NumThreadsZ{0}; // Z component
+  unsigned WaveSizeMin{0}; // Minimum component
+  unsigned WaveSizeMax{0}; // Maximum component
+  unsigned WaveSizePref{0}; // Preferred component
 
   EntryProperties(const Function *Fn = nullptr) : Entry(Fn) {};
 };
diff --git a/llvm/lib/Analysis/DXILMetadataAnalysis.cpp b/llvm/lib/Analysis/DXILMetadataAnalysis.cpp
index 23f1aa82ae8a3..bd77cba385667 100644
--- a/llvm/lib/Analysis/DXILMetadataAnalysis.cpp
+++ b/llvm/lib/Analysis/DXILMetadataAnalysis.cpp
@@ -66,6 +66,22 @@ static ModuleMetadataInfo collectMetadataInfo(Module &M) {
       Success = llvm::to_integer(NumThreadsVec[2], EFP.NumThreadsZ, 10);
       assert(Success && "Failed to parse Z component of numthreads");
     }
+    // Get wavesize attribute value, if one exists
+    StringRef WaveSizeStr =
+        F.getFnAttribute("hlsl.wavesize").getValueAsString();
+    if (!WaveSizeStr.empty()) {
+      SmallVector<StringRef> WaveSizeVec;
+      WaveSizeStr.split(WaveSizeVec, ',');
+      assert(WaveSizeVec.size() == 3 && "Invalid wavesize specified");
+      // Read in the three component values of wavesize
+      [[maybe_unused]] bool Success =
+          llvm::to_integer(WaveSizeVec[0], EFP.WaveSizeMin, 10);
+      assert(Success && "Failed to parse Min component of wavesize");
+      Success = llvm::to_integer(WaveSizeVec[1], EFP.WaveSizeMax, 10);
+      assert(Success && "Failed to parse Max component of wavesize");
+      Success = llvm::to_integer(WaveSizeVec[2], EFP.WaveSizePref, 10);
+      assert(Success && "Failed to parse Preferred component of wavesize");
+    }
     MMDAI.EntryPropertyVec.push_back(EFP);
   }
   return MMDAI;
diff --git a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
index cf8b833b3e42e..682847a94c6fb 100644
--- a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
+++ b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
@@ -82,6 +82,7 @@ enum class EntryPropsTag {
   ASStateTag,
   WaveSize,
   EntryRootSig,
+  WaveRange = 23,
 };
 
 } // namespace
@@ -177,30 +178,32 @@ getTagValueAsMetadata(EntryPropsTag Tag, uint64_t Value, LLVMContext &Ctx) {
   case EntryPropsTag::ASStateTag:
   case EntryPropsTag::WaveSize:
   case EntryPropsTag::EntryRootSig:
+  case EntryPropsTag::WaveRange:
     llvm_unreachable("NYI: Unhandled entry property tag");
   }
   return MDVals;
 }
 
-static MDTuple *
-getEntryPropAsMetadata(const EntryProperties &EP, uint64_t EntryShaderFlags,
-                       const Triple::EnvironmentType ShaderProfile) {
+static MDTuple *getEntryPropAsMetadata(Module &M, const EntryProperties &EP,
+                                       uint64_t EntryShaderFlags,
+                                       const ModuleMetadataInfo &MMDI) {
   SmallVector<Metadata *> MDVals;
   LLVMContext &Ctx = EP.Entry->getContext();
   if (EntryShaderFlags != 0)
     MDVals.append(getTagValueAsMetadata(EntryPropsTag::ShaderFlags,
-                                        EntryShaderFlags, Ctx));
+                                        MMDI.ShaderProfile, Ctx));
 
   if (EP.Entry != nullptr) {
     // FIXME: support more props.
     // See https://github.com/llvm/llvm-project/issues/57948.
     // Add shader kind for lib entries.
-    if (ShaderProfile == Triple::EnvironmentType::Library &&
+    if (MMDI.ShaderProfile == Triple::EnvironmentType::Library &&
         EP.ShaderStage != Triple::EnvironmentType::Library)
       MDVals.append(getTagValueAsMetadata(EntryPropsTag::ShaderKind,
                                           getShaderStage(EP.ShaderStage), Ctx));
 
     if (EP.ShaderStage == Triple::EnvironmentType::Compute) {
+      // Handle mandatory "hlsl.numthreads"
       MDVals.emplace_back(ConstantAsMetadata::get(ConstantInt::get(
           Type::getInt32Ty(Ctx), static_cast<int>(EntryPropsTag::NumThreads))));
       Metadata *NumThreadVals[] = {ConstantAsMetadata::get(ConstantInt::get(
@@ -210,8 +213,47 @@ getEntryPropAsMetadata(const EntryProperties &EP, uint64_t EntryShaderFlags,
                                    ConstantAsMetadata::get(ConstantInt::get(
                                        Type::getInt32Ty(Ctx), EP.NumThreadsZ))};
       MDVals.emplace_back(MDNode::get(Ctx, NumThreadVals));
+
+      // Handle optional "hlsl.wavesize". The metadata is only emitted when the
+      // minimum wave size is non-zero.
+      if (EP.WaveSizeMin != 0) {
+        bool IsWaveRange = VersionTuple(6, 8) <= MMDI.ShaderModelVersion;
+        bool IsWaveSize =
+            !IsWaveRange && VersionTuple(6, 6) <= MMDI.ShaderModelVersion;
+
+        if (!IsWaveRange && !IsWaveSize) {
+          reportError(M, "Shader model 6.6 or greater is required to specify "
+                         "the \"hlsl.wavesize\" function attribute");
+          return nullptr;
+        }
+
+        if (EP.WaveSizeMax && !IsWaveRange) {
+          reportError(
+              M, "Shader model 6.8 or greater is required to specify "
+                 "wave size range values of the \"hlsl.wavesize\" function "
+                 "attribute");
+          return nullptr;
+        }
+
+        EntryPropsTag Tag =
+            IsWaveSize ? EntryPropsTag::WaveSize : EntryPropsTag::WaveRange;
+        MDVals.emplace_back(ConstantAsMetadata::get(
+            ConstantInt::get(Type::getInt32Ty(Ctx), static_cast<int>(Tag))));
+
+        SmallVector<Metadata *> WaveSizeVals = {ConstantAsMetadata::get(
+            ConstantInt::get(Type::getInt32Ty(Ctx), EP.WaveSizeMin))};
+        if (IsWaveRange) {
+          WaveSizeVals.push_back(ConstantAsMetadata::get(
+              ConstantInt::get(Type::getInt32Ty(Ctx), EP.WaveSizeMax)));
+          WaveSizeVals.push_back(ConstantAsMetadata::get(
+              ConstantInt::get(Type::getInt32Ty(Ctx), EP.WaveSizePref)));
+        }
+
+        MDVals.emplace_back(MDNode::get(Ctx, WaveSizeVals));
+      }
     }
   }
+
   if (MDVals.empty())
     return nullptr;
   return MDNode::get(Ctx, MDVals);
@@ -236,12 +278,11 @@ static MDTuple *constructEntryMetadata(const Function *EntryFn,
   return MDNode::get(Ctx, MDVals);
 }
 
-static MDTuple *emitEntryMD(const EntryProperties &EP, MDTuple *Signatures,
-                            MDNode *MDResources,
+static MDTuple *emitEntryMD(Module &M, const EntryProperties &EP,
+                            MDTuple *Signatures, MDNode *MDResources,
                             const uint64_t EntryShaderFlags,
-                            const Triple::EnvironmentType ShaderProfile) {
-  MDTuple *Properties =
-      getEntryPropAsMetadata(EP, EntryShaderFlags, ShaderProfile);
+                            const ModuleMetadataInfo &MMDI) {
+  MDTuple *Properties = getEntryPropAsMetadata(M, EP, EntryShaderFlags, MMDI);
   return constructEntryMetadata(EP.Entry, Signatures, MDResources, Properties,
                                 EP.Entry->getContext());
 }
@@ -523,10 +564,8 @@ static void translateGlobalMetadata(Module &M, DXILResourceMap &DRM,
                    Twine(Triple::getEnvironmentTypeName(MMDI.ShaderProfile) +
                          "'"));
     }
-
-    EntryFnMDNodes.emplace_back(emitEntryMD(EntryProp, Signatures, ResourceMD,
-                                            EntryShaderFlags,
-                                            MMDI.ShaderProfile));
+    EntryFnMDNodes.emplace_back(emitEntryMD(
+        M, EntryProp, Signatures, ResourceMD, EntryShaderFlags, MMDI));
   }
 
   NamedMDNode *EntryPointsNamedMD =
diff --git a/llvm/test/CodeGen/DirectX/wavesize-md-errs.ll b/llvm/test/CodeGen/DirectX/wavesize-md-errs.ll
new file mode 100644
index 0000000000000..9016c5d7e8d44
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/wavesize-md-errs.ll
@@ -0,0 +1,31 @@
+; RUN: split-file %s %t
+; RUN: not opt -S --dxil-translate-metadata %t/low-sm.ll 2>&1 | FileCheck %t/low-sm.ll
+; RUN: not opt -S --dxil-translate-metadata %t/low-sm-for-range.ll 2>&1 | FileCheck %t/low-sm-for-range.ll
+
+; Test that wavesize metadata is only allowed on applicable shader model versions
+
+;--- low-sm.ll
+
+; CHECK: Shader model 6.6 or greater is required to specify the "hlsl.wavesize" function attribute
+
+target triple = "dxil-unknown-shadermodel6.5-compute"
+
+define void @main() #0 {
+entry:
+  ret void
+}
+
+attributes #0 = { "hlsl.wavesize"="16,0,0" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
+
+;--- low-sm-for-range.ll
+
+; CHECK: Shader model 6.8 or greater is required to specify wave size range values of the "hlsl.wavesize" function attribute
+
+target triple = "dxil-unknown-shadermodel6.7-compute"
+
+define void @main() #0 {
+entry:
+  ret void
+}
+
+attributes #0 = { "hlsl.wavesize"="16,32,0" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
diff --git a/llvm/test/CodeGen/DirectX/wavesize-md-valid.ll b/llvm/test/CodeGen/DirectX/wavesize-md-valid.ll
new file mode 100644
index 0000000000000..63e8a59eb2648
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/wavesize-md-valid.ll
@@ -0,0 +1,71 @@
+; RUN: split-file %s %t
+; RUN: opt -S --dxil-translate-metadata %t/only.ll | FileCheck %t/only.ll
+; RUN: opt -S --dxil-translate-metadata %t/min.ll | FileCheck %t/min.ll
+; RUN: opt -S --dxil-translate-metadata %t/max.ll | FileCheck %t/max.ll
+; RUN: opt -S --dxil-translate-metadata %t/pref.ll | FileCheck %t/pref.ll
+
+; Test that wave size/range metadata is correctly generated with the correct tag
+
+;--- only.ll
+
+; CHECK: !dx.entryPoints = !{![[#ENTRY:]]}
+; CHECK: ![[#ENTRY]] = !{ptr @main, !"main", null, null, ![[#PROPS:]]}
+; CHECK: ![[#PROPS]] = !{{{.*}}i32 11, ![[#WAVE_SIZE:]]{{.*}}}
+; CHECK: ![[#WAVE_SIZE]] = !{i32 16}
+
+target triple = "dxil-unknown-shadermodel6.6-compute"
+
+define void @main() #0 {
+entry:
+  ret void
+}
+
+attributes #0 = { "hlsl.wavesize"="16,0,0" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
+
+;--- min.ll
+
+; CHECK: !dx.entryPoints = !{![[#ENTRY:]]}
+; CHECK: ![[#ENTRY]] = !{ptr @main, !"main", null, null, ![[#PROPS:]]}
+; CHECK: ![[#PROPS]] = !{{{.*}}i32 23, ![[#WAVE_SIZE:]]{{.*}}}
+; CHECK: ![[#WAVE_SIZE]] = !{i32 16, i32 0, i32 0}
+
+target triple = "dxil-unknown-shadermodel6.8-compute"
+
+define void @main() #0 {
+entry:
+  ret void
+}
+
+attributes #0 = { "hlsl.wavesize"="16,0,0" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
+
+;--- max.ll
+
+; CHECK: !dx.entryPoints = !{![[#ENTRY:]]}
+; CHECK: ![[#ENTRY]] = !{ptr @main, !"main", null, null, ![[#PROPS:]]}
+; CHECK: ![[#PROPS]] = !{{{.*}}i32 23, ![[#WAVE_SIZE:]]{{.*}}}
+; CHECK: ![[#WAVE_SIZE]] = !{i32 16, i32 32, i32 0}
+
+target triple = "dxil-unknown-shadermodel6.8-compute"
+
+define void @main() #0 {
+entry:
+  ret void
+}
+
+attributes #0 = { "hlsl.wavesize"="16,32,0" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
+
+;--- pref.ll
+
+; CHECK: !dx.entryPoints = !{![[#ENTRY:]]}
+; CHECK: ![[#ENTRY]] = !{ptr @main, !"main", null, null, ![[#PROPS:]]}
+; CHECK: ![[#PROPS]] = !{{{.*}}i32 23, ![[#WAVE_SIZE:]]{{.*}}}
+; CHECK: ![[#WAVE_SIZE]] = !{i32 16, i32 64, i32 32}
+
+target triple = "dxil-unknown-shadermodel6.8-compute"
+
+define void @main() #0 {
+entry:
+  ret void
+}
+
+attributes #0 = { "hlsl.wavesize"="16,64,32" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }

>From ff1838eed00dccf45726eaa45b1de14277e951e3 Mon Sep 17 00:00:00 2001
From: Finn Plummer <mail at inbelic.dev>
Date: Wed, 29 Oct 2025 14:02:42 -0700
Subject: [PATCH 2/2] small corrections

---
 llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
index 682847a94c6fb..e1a472fe57642 100644
--- a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
+++ b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
@@ -191,7 +191,7 @@ static MDTuple *getEntryPropAsMetadata(Module &M, const EntryProperties &EP,
   LLVMContext &Ctx = EP.Entry->getContext();
   if (EntryShaderFlags != 0)
     MDVals.append(getTagValueAsMetadata(EntryPropsTag::ShaderFlags,
-                                        MMDI.ShaderProfile, Ctx));
+                                        EntryShaderFlags, Ctx));
 
   if (EP.Entry != nullptr) {
     // FIXME: support more props.
@@ -227,6 +227,7 @@ static MDTuple *getEntryPropAsMetadata(Module &M, const EntryProperties &EP,
           return nullptr;
         }
 
+        // A range is being specified if EP.WaveSizeMax != 0
         if (EP.WaveSizeMax && !IsWaveRange) {
           reportError(
               M, "Shader model 6.8 or greater is required to specify "



More information about the llvm-commits mailing list