[llvm-branch-commits] [clang] [llvm] [DirectX] Validate registers are bound to root signature (PR #146785)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Fri Jul 4 17:37:26 PDT 2025


https://github.com/joaosaffran updated https://github.com/llvm/llvm-project/pull/146785

>From a49aa19297811e5800ffce364d8d6a225109d93f Mon Sep 17 00:00:00 2001
From: joaosaffran <joao.saffran at microsoft.com>
Date: Thu, 26 Jun 2025 19:28:01 +0000
Subject: [PATCH 1/8] Refactor root signature analysis to stop caching its result

---
 .../lib/Target/DirectX/DXContainerGlobals.cpp |  4 ++-
 llvm/lib/Target/DirectX/DXILRootSignature.cpp | 14 +++-----
 llvm/lib/Target/DirectX/DXILRootSignature.h   | 33 +++++++++----------
 3 files changed, 23 insertions(+), 28 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
index 6c8ae8eaaea77..e076283b65193 100644
--- a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
+++ b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
@@ -160,11 +160,13 @@ void DXContainerGlobals::addRootSignature(Module &M,
 
   assert(MMI.EntryPropertyVec.size() == 1);
 
+  auto &RSA = getAnalysis<RootSignatureAnalysisWrapper>().getRSInfo();
   auto &RSA = getAnalysis<RootSignatureAnalysisWrapper>().getRSInfo();
   const Function *EntryFunction = MMI.EntryPropertyVec[0].Entry;
   const auto &RS = RSA.getDescForFunction(EntryFunction);
+  const auto &RS = RSA.getDescForFunction(EntryFunction);
 
-  if (!RS)
+  if (!RS )
     return;
 
   SmallString<256> Data;
diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.cpp b/llvm/lib/Target/DirectX/DXILRootSignature.cpp
index 5a53ea8a3631b..4094df160ef6f 100644
--- a/llvm/lib/Target/DirectX/DXILRootSignature.cpp
+++ b/llvm/lib/Target/DirectX/DXILRootSignature.cpp
@@ -554,12 +554,9 @@ analyzeModule(Module &M) {
 
 AnalysisKey RootSignatureAnalysis::Key;
 
-RootSignatureAnalysis::Result
-RootSignatureAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
-  if (!AnalysisResult)
-    AnalysisResult = std::make_unique<RootSignatureBindingInfo>(
-        RootSignatureBindingInfo(analyzeModule(M)));
-  return *AnalysisResult;
+RootSignatureBindingInfo RootSignatureAnalysis::run(Module &M,
+                                                    ModuleAnalysisManager &AM) {
+  return RootSignatureBindingInfo(analyzeModule(M));
 }
 
 //===----------------------------------------------------------------------===//
@@ -638,9 +635,8 @@ PreservedAnalyses RootSignatureAnalysisPrinter::run(Module &M,
 
 //===----------------------------------------------------------------------===//
 bool RootSignatureAnalysisWrapper::runOnModule(Module &M) {
-  if (!FuncToRsMap)
-    FuncToRsMap = std::make_unique<RootSignatureBindingInfo>(
-        RootSignatureBindingInfo(analyzeModule(M)));
+  FuncToRsMap = std::make_unique<RootSignatureBindingInfo>(
+      RootSignatureBindingInfo(analyzeModule(M)));
   return false;
 }
 
diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.h b/llvm/lib/Target/DirectX/DXILRootSignature.h
index 3832182277050..24b1a8d3d2abe 100644
--- a/llvm/lib/Target/DirectX/DXILRootSignature.h
+++ b/llvm/lib/Target/DirectX/DXILRootSignature.h
@@ -37,30 +37,28 @@ enum class RootSignatureElementKind {
 };
 
 class RootSignatureBindingInfo {
-private:
-  SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc> FuncToRsMap;
+  private:
+    SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc> FuncToRsMap;
 
-public:
+  public:
   using iterator =
-      SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc>::iterator;
+        SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc>::iterator;
 
-  RootSignatureBindingInfo() = default;
-  RootSignatureBindingInfo(
-      SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc> Map)
-      : FuncToRsMap(Map) {};
+  RootSignatureBindingInfo () = default;
+  RootSignatureBindingInfo(SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc> Map) : FuncToRsMap(Map) {};
 
   iterator find(const Function *F) { return FuncToRsMap.find(F); }
 
   iterator end() { return FuncToRsMap.end(); }
 
-  std::optional<mcdxbc::RootSignatureDesc>
-  getDescForFunction(const Function *F) {
+  std::optional<mcdxbc::RootSignatureDesc> getDescForFunction(const Function* F) {
     const auto FuncRs = find(F);
     if (FuncRs == end())
       return std::nullopt;
 
     return FuncRs->second;
   }
+  
 };
 
 class RootSignatureAnalysis : public AnalysisInfoMixin<RootSignatureAnalysis> {
@@ -68,14 +66,13 @@ class RootSignatureAnalysis : public AnalysisInfoMixin<RootSignatureAnalysis> {
   static AnalysisKey Key;
 
 public:
-  RootSignatureAnalysis() = default;
-
-  using Result = RootSignatureBindingInfo;
 
-  Result run(Module &M, ModuleAnalysisManager &AM);
+RootSignatureAnalysis() = default;
 
-private:
-  std::unique_ptr<RootSignatureBindingInfo> AnalysisResult;
+  using Result = RootSignatureBindingInfo;
+  
+  RootSignatureBindingInfo
+  run(Module &M, ModuleAnalysisManager &AM);
 };
 
 /// Wrapper pass for the legacy pass manager.
@@ -92,8 +89,8 @@ class RootSignatureAnalysisWrapper : public ModulePass {
 
   RootSignatureAnalysisWrapper() : ModulePass(ID) {}
 
-  RootSignatureBindingInfo &getRSInfo() { return *FuncToRsMap; }
-
+  RootSignatureBindingInfo& getRSInfo() {return *FuncToRsMap;}
+  
   bool runOnModule(Module &M) override;
 
   void getAnalysisUsage(AnalysisUsage &AU) const override;

>From d90676feb6bfc0ca8bbdaee5c347ecc49e396b5b Mon Sep 17 00:00:00 2001
From: joaosaffran <joao.saffran at microsoft.com>
Date: Thu, 26 Jun 2025 21:37:11 +0000
Subject: [PATCH 2/8] Pass root signature and module metadata into post-optimization validation

---
 .../SemaHLSL/RootSignature-Validation.hlsl    | 42 +++++++++++++++++
 .../lib/Target/DirectX/DXContainerGlobals.cpp |  2 +-
 .../DXILPostOptimizationValidation.cpp        | 47 +++++++++++++++++--
 llvm/lib/Target/DirectX/DXILRootSignature.h   | 30 ++++++------
 4 files changed, 102 insertions(+), 19 deletions(-)
 create mode 100644 clang/test/SemaHLSL/RootSignature-Validation.hlsl

diff --git a/clang/test/SemaHLSL/RootSignature-Validation.hlsl b/clang/test/SemaHLSL/RootSignature-Validation.hlsl
new file mode 100644
index 0000000000000..8a4a97f87cb65
--- /dev/null
+++ b/clang/test/SemaHLSL/RootSignature-Validation.hlsl
@@ -0,0 +1,42 @@
+// RUN: %clang_dxc -triple dxil-pc-shadermodel6.3-library -x hlsl -o - %s -verify
+
+#define ROOT_SIGNATURE \
+    "RootFlags(ALLOW_INPUT_ASSEMBLER_INPUT_LAYOUT), " \
+    "CBV(b0, visibility=SHADER_VISIBILITY_ALL), " \
+    "DescriptorTable(SRV(t0, numDescriptors=3), visibility=SHADER_VISIBILITY_PIXEL), " \
+    "DescriptorTable(Sampler(s0, numDescriptors=2), visibility=SHADER_VISIBILITY_PIXEL), " \
+    "DescriptorTable(UAV(u0, numDescriptors=1), visibility=SHADER_VISIBILITY_ALL)"
+
+cbuffer CB : register(b3, space2) {
+  float a;
+}
+
+StructuredBuffer<int> In : register(t0);
+RWStructuredBuffer<int> Out : register(u0);
+
+RWBuffer<float> UAV : register(u3);
+
+RWBuffer<float> UAV1 : register(u2), UAV2 : register(u4);
+
+RWBuffer<float> UAV3 : register(space5);
+
+float f : register(c5);
+
+int4 intv : register(c2);
+
+double dar[5] :  register(c3);
+
+struct S {
+  int a;
+};
+
+S s : register(c10);
+
+// Compute Shader for UAV testing
+[numthreads(8, 8, 1)]
+[RootSignature(ROOT_SIGNATURE)]
+void CSMain(uint3 id : SV_DispatchThreadID)
+{
+    In[0] = id;
+    Out[0] = In[0];
+}
diff --git a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
index e076283b65193..5c763c24a210a 100644
--- a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
+++ b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
@@ -166,7 +166,7 @@ void DXContainerGlobals::addRootSignature(Module &M,
   const auto &RS = RSA.getDescForFunction(EntryFunction);
   const auto &RS = RSA.getDescForFunction(EntryFunction);
 
-  if (!RS )
+  if (!RS)
     return;
 
   SmallString<256> Data;
diff --git a/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp b/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp
index 398dcbb8d1737..daf53fefe5f17 100644
--- a/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp
+++ b/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp
@@ -7,11 +7,14 @@
 //===----------------------------------------------------------------------===//
 
 #include "DXILPostOptimizationValidation.h"
+#include "DXILRootSignature.h"
 #include "DXILShaderFlags.h"
 #include "DirectX.h"
+#include "llvm/ADT/STLForwardCompat.h"
 #include "llvm/ADT/SmallString.h"
 #include "llvm/Analysis/DXILMetadataAnalysis.h"
 #include "llvm/Analysis/DXILResource.h"
+#include "llvm/BinaryFormat/DXContainer.h"
 #include "llvm/IR/DiagnosticInfo.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicsDirectX.h"
@@ -85,7 +88,9 @@ static void reportOverlappingBinding(Module &M, DXILResourceMap &DRM) {
 }
 
 static void reportErrors(Module &M, DXILResourceMap &DRM,
-                         DXILResourceBindingInfo &DRBI) {
+                         DXILResourceBindingInfo &DRBI,
+                         RootSignatureBindingInfo &RSBI,
+                         dxil::ModuleMetadataInfo &MMI) {
   if (DRM.hasInvalidCounterDirection())
     reportInvalidDirection(M, DRM);
 
@@ -94,6 +99,30 @@ static void reportErrors(Module &M, DXILResourceMap &DRM,
 
   assert(!DRBI.hasImplicitBinding() && "implicit bindings should be handled in "
                                        "DXILResourceImplicitBinding pass");
+  // Assuming this is used to validate only the root signature assigned to the
+  // entry function.
+  std::optional<mcdxbc::RootSignatureDesc> RootSigDesc =
+      RSBI.getDescForFunction(MMI.EntryPropertyVec[0].Entry);
+  if (!RootSigDesc)
+    return;
+
+  for (const mcdxbc::RootParameterInfo &Info :
+       RootSigDesc->ParametersContainer) {
+    const auto &[Type, Loc] =
+        RootSigDesc->ParametersContainer.getTypeAndLocForParameter(
+            Info.Location);
+    switch (Type) {
+    case llvm::to_underlying(dxbc::RootParameterType::CBV):
+      dxbc::RTS0::v2::RootDescriptor Desc =
+          RootSigDesc->ParametersContainer.getRootDescriptor(Loc);
+
+      llvm::dxil::ResourceInfo::ResourceBinding Binding;
+      Binding.LowerBound = Desc.ShaderRegister;
+      Binding.Space = Desc.RegisterSpace;
+      Binding.Size = 1;
+      break;
+    }
+  }
 }
 } // namespace
 
@@ -101,7 +130,10 @@ PreservedAnalyses
 DXILPostOptimizationValidation::run(Module &M, ModuleAnalysisManager &MAM) {
   DXILResourceMap &DRM = MAM.getResult<DXILResourceAnalysis>(M);
   DXILResourceBindingInfo &DRBI = MAM.getResult<DXILResourceBindingAnalysis>(M);
-  reportErrors(M, DRM, DRBI);
+  RootSignatureBindingInfo &RSBI = MAM.getResult<RootSignatureAnalysis>(M);
+  ModuleMetadataInfo &MMI = MAM.getResult<DXILMetadataAnalysis>(M);
+
+  reportErrors(M, DRM, DRBI, RSBI, MMI);
   return PreservedAnalyses::all();
 }
 
@@ -113,7 +145,13 @@ class DXILPostOptimizationValidationLegacy : public ModulePass {
         getAnalysis<DXILResourceWrapperPass>().getResourceMap();
     DXILResourceBindingInfo &DRBI =
         getAnalysis<DXILResourceBindingWrapperPass>().getBindingInfo();
-    reportErrors(M, DRM, DRBI);
+
+    RootSignatureBindingInfo &RSBI =
+        getAnalysis<RootSignatureAnalysisWrapper>().getRSInfo();
+    dxil::ModuleMetadataInfo &MMI =
+        getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();
+
+    reportErrors(M, DRM, DRBI, RSBI, MMI);
     return false;
   }
   StringRef getPassName() const override {
@@ -125,10 +163,13 @@ class DXILPostOptimizationValidationLegacy : public ModulePass {
   void getAnalysisUsage(llvm::AnalysisUsage &AU) const override {
     AU.addRequired<DXILResourceWrapperPass>();
     AU.addRequired<DXILResourceBindingWrapperPass>();
+    AU.addRequired<RootSignatureAnalysisWrapper>();
+    AU.addRequired<DXILMetadataAnalysisWrapperPass>();
     AU.addPreserved<DXILResourceWrapperPass>();
     AU.addPreserved<DXILResourceBindingWrapperPass>();
     AU.addPreserved<DXILMetadataAnalysisWrapperPass>();
     AU.addPreserved<ShaderFlagsAnalysisWrapper>();
+    AU.addPreserved<RootSignatureAnalysisWrapper>();
   }
 };
 char DXILPostOptimizationValidationLegacy::ID = 0;
diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.h b/llvm/lib/Target/DirectX/DXILRootSignature.h
index 24b1a8d3d2abe..ecfc577d1b97d 100644
--- a/llvm/lib/Target/DirectX/DXILRootSignature.h
+++ b/llvm/lib/Target/DirectX/DXILRootSignature.h
@@ -37,28 +37,30 @@ enum class RootSignatureElementKind {
 };
 
 class RootSignatureBindingInfo {
-  private:
-    SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc> FuncToRsMap;
+private:
+  SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc> FuncToRsMap;
 
-  public:
+public:
   using iterator =
-        SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc>::iterator;
+      SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc>::iterator;
 
-  RootSignatureBindingInfo () = default;
-  RootSignatureBindingInfo(SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc> Map) : FuncToRsMap(Map) {};
+  RootSignatureBindingInfo() = default;
+  RootSignatureBindingInfo(
+      SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc> Map)
+      : FuncToRsMap(Map){};
 
   iterator find(const Function *F) { return FuncToRsMap.find(F); }
 
   iterator end() { return FuncToRsMap.end(); }
 
-  std::optional<mcdxbc::RootSignatureDesc> getDescForFunction(const Function* F) {
+  std::optional<mcdxbc::RootSignatureDesc>
+  getDescForFunction(const Function *F) {
     const auto FuncRs = find(F);
     if (FuncRs == end())
       return std::nullopt;
 
     return FuncRs->second;
   }
-  
 };
 
 class RootSignatureAnalysis : public AnalysisInfoMixin<RootSignatureAnalysis> {
@@ -66,13 +68,11 @@ class RootSignatureAnalysis : public AnalysisInfoMixin<RootSignatureAnalysis> {
   static AnalysisKey Key;
 
 public:
-
-RootSignatureAnalysis() = default;
+  RootSignatureAnalysis() = default;
 
   using Result = RootSignatureBindingInfo;
-  
-  RootSignatureBindingInfo
-  run(Module &M, ModuleAnalysisManager &AM);
+
+  RootSignatureBindingInfo run(Module &M, ModuleAnalysisManager &AM);
 };
 
 /// Wrapper pass for the legacy pass manager.
@@ -89,8 +89,8 @@ class RootSignatureAnalysisWrapper : public ModulePass {
 
   RootSignatureAnalysisWrapper() : ModulePass(ID) {}
 
-  RootSignatureBindingInfo& getRSInfo() {return *FuncToRsMap;}
-  
+  RootSignatureBindingInfo &getRSInfo() { return *FuncToRsMap; }
+
   bool runOnModule(Module &M) override;
 
   void getAnalysisUsage(AnalysisUsage &AU) const override;

>From a04eb9ff37d20499f05c7b1cc0ab3187f729609b Mon Sep 17 00:00:00 2001
From: joaosaffran <joao.saffran at microsoft.com>
Date: Wed, 2 Jul 2025 17:58:56 +0000
Subject: [PATCH 3/8] Add initial CBV binding check against the root signature

---
 .../SemaHLSL/RootSignature-Validation.hlsl    | 28 ++++---------
 .../DXILPostOptimizationValidation.cpp        | 42 +++++++++++++++----
 2 files changed, 43 insertions(+), 27 deletions(-)

diff --git a/clang/test/SemaHLSL/RootSignature-Validation.hlsl b/clang/test/SemaHLSL/RootSignature-Validation.hlsl
index 8a4a97f87cb65..62ba704b95c7d 100644
--- a/clang/test/SemaHLSL/RootSignature-Validation.hlsl
+++ b/clang/test/SemaHLSL/RootSignature-Validation.hlsl
@@ -1,42 +1,30 @@
-// RUN: %clang_dxc -triple dxil-pc-shadermodel6.3-library -x hlsl -o - %s -verify
 
 #define ROOT_SIGNATURE \
     "RootFlags(ALLOW_INPUT_ASSEMBLER_INPUT_LAYOUT), " \
-    "CBV(b0, visibility=SHADER_VISIBILITY_ALL), " \
-    "DescriptorTable(SRV(t0, numDescriptors=3), visibility=SHADER_VISIBILITY_PIXEL), " \
-    "DescriptorTable(Sampler(s0, numDescriptors=2), visibility=SHADER_VISIBILITY_PIXEL), " \
-    "DescriptorTable(UAV(u0, numDescriptors=1), visibility=SHADER_VISIBILITY_ALL)"
+    "CBV(b3, space=1, visibility=SHADER_VISIBILITY_ALL), " \
+    "DescriptorTable(SRV(t0, space=0, numDescriptors=1), visibility=SHADER_VISIBILITY_ALL), " \
+    "DescriptorTable(Sampler(s0, numDescriptors=2), visibility=SHADER_VISIBILITY_ALL), " \
+    "DescriptorTable(UAV(u0, numDescriptors=unbounded), visibility=SHADER_VISIBILITY_ALL)"
 
 cbuffer CB : register(b3, space2) {
   float a;
 }
 
-StructuredBuffer<int> In : register(t0);
+StructuredBuffer<int> In : register(t0, space0);
 RWStructuredBuffer<int> Out : register(u0);
 
 RWBuffer<float> UAV : register(u3);
 
 RWBuffer<float> UAV1 : register(u2), UAV2 : register(u4);
 
-RWBuffer<float> UAV3 : register(space5);
+RWBuffer<float> UAV3 : register(space0);
 
-float f : register(c5);
 
-int4 intv : register(c2);
-
-double dar[5] :  register(c3);
-
-struct S {
-  int a;
-};
-
-S s : register(c10);
 
 // Compute Shader for UAV testing
 [numthreads(8, 8, 1)]
 [RootSignature(ROOT_SIGNATURE)]
-void CSMain(uint3 id : SV_DispatchThreadID)
+void CSMain(uint id : SV_GroupID)
 {
-    In[0] = id;
-    Out[0] = In[0];
+    Out[0] = a + id + In[0] + UAV[0] + UAV1[0] + UAV3[0];
 }
diff --git a/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp b/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp
index daf53fefe5f17..3e542e502c2d5 100644
--- a/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp
+++ b/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp
@@ -10,6 +10,7 @@
 #include "DXILRootSignature.h"
 #include "DXILShaderFlags.h"
 #include "DirectX.h"
+#include "llvm/ADT/IntervalMap.h"
 #include "llvm/ADT/STLForwardCompat.h"
 #include "llvm/ADT/SmallString.h"
 #include "llvm/Analysis/DXILMetadataAnalysis.h"
@@ -86,7 +87,9 @@ static void reportOverlappingBinding(Module &M, DXILResourceMap &DRM) {
     }
   }
 }
-
+  uint64_t combine_uint32_to_uint64(uint32_t high, uint32_t low) {
+      return (static_cast<uint64_t>(high) << 32) | low;
+  }
 static void reportErrors(Module &M, DXILResourceMap &DRM,
                          DXILResourceBindingInfo &DRBI,
                          RootSignatureBindingInfo &RSBI,
@@ -101,18 +104,24 @@ static void reportErrors(Module &M, DXILResourceMap &DRM,
                                        "DXILResourceImplicitBinding pass");
   // Assuming this is used to validate only the root signature assigned to the
   // entry function.
+  //Start test stuff
+  if(MMI.EntryPropertyVec.size() == 0)
+    return;
+
   std::optional<mcdxbc::RootSignatureDesc> RootSigDesc =
       RSBI.getDescForFunction(MMI.EntryPropertyVec[0].Entry);
   if (!RootSigDesc)
     return;
 
-  for (const mcdxbc::RootParameterInfo &Info :
-       RootSigDesc->ParametersContainer) {
+  using MapT = llvm::IntervalMap<uint64_t, llvm::dxil::ResourceInfo::ResourceBinding, sizeof(llvm::dxil::ResourceInfo::ResourceBinding), llvm::IntervalMapInfo<uint64_t>>;
+  MapT::Allocator Allocator;
+  MapT BindingsMap(Allocator);
+  auto RSD = *RootSigDesc;
+   for (size_t I = 0; I < RSD.ParametersContainer.size(); I++) {
     const auto &[Type, Loc] =
-        RootSigDesc->ParametersContainer.getTypeAndLocForParameter(
-            Info.Location);
+        RootSigDesc->ParametersContainer.getTypeAndLocForParameter(I);
     switch (Type) {
-    case llvm::to_underlying(dxbc::RootParameterType::CBV):
+    case llvm::to_underlying(dxbc::RootParameterType::CBV):{
       dxbc::RTS0::v2::RootDescriptor Desc =
           RootSigDesc->ParametersContainer.getRootDescriptor(Loc);
 
@@ -120,8 +129,27 @@ static void reportErrors(Module &M, DXILResourceMap &DRM,
       Binding.LowerBound = Desc.ShaderRegister;
       Binding.Space = Desc.RegisterSpace;
       Binding.Size = 1;
+
+      BindingsMap.insert(combine_uint32_to_uint64(Binding.Space, Binding.LowerBound), combine_uint32_to_uint64(Binding.Space, Binding.LowerBound + Binding.Size -1), Binding);
       break;
     }
+    // case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable):{
+    //   mcdxbc::DescriptorTable Table =
+    //       RootSigDesc->ParametersContainer.getDescriptorTable(Loc);
+    //   for (const dxbc::RTS0::v2::DescriptorRange &Range : Table){
+    //     Range.
+    //   }
+      
+    //   break;
+    // }
+    }
+
+  }
+
+  for(const auto &CBuf : DRM.cbuffers()) {
+    auto Binding = CBuf.getBinding();
+    if(!BindingsMap.overlaps(combine_uint32_to_uint64(Binding.Space, Binding.LowerBound), combine_uint32_to_uint64(Binding.Space, Binding.LowerBound + Binding.Size -1)))
+      auto X = 1;
   }
 }
 } // namespace
@@ -146,7 +174,7 @@ class DXILPostOptimizationValidationLegacy : public ModulePass {
     DXILResourceBindingInfo &DRBI =
         getAnalysis<DXILResourceBindingWrapperPass>().getBindingInfo();
 
-    RootSignatureBindingInfo &RSBI =
+    RootSignatureBindingInfo& RSBI =
         getAnalysis<RootSignatureAnalysisWrapper>().getRSInfo();
     dxil::ModuleMetadataInfo &MMI =
         getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();

>From 5994b8f8f4ea24115a66c0046c8fc344905b41d4 Mon Sep 17 00:00:00 2001
From: joaosaffran <joao.saffran at microsoft.com>
Date: Wed, 2 Jul 2025 21:19:37 +0000
Subject: [PATCH 4/8] Clean up includes in DXILPostOptimizationValidation

---
 .../DXILPostOptimizationValidation.cpp        |  6 +----
 .../DirectX/DXILPostOptimizationValidation.h  |  3 +++
 llvm/lib/Target/DirectX/DXILRootSignature.h   | 24 +++++++++----------
 3 files changed, 15 insertions(+), 18 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp b/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp
index 3e542e502c2d5..4c29b56304391 100644
--- a/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp
+++ b/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp
@@ -10,12 +10,9 @@
 #include "DXILRootSignature.h"
 #include "DXILShaderFlags.h"
 #include "DirectX.h"
-#include "llvm/ADT/IntervalMap.h"
-#include "llvm/ADT/STLForwardCompat.h"
 #include "llvm/ADT/SmallString.h"
 #include "llvm/Analysis/DXILMetadataAnalysis.h"
 #include "llvm/Analysis/DXILResource.h"
-#include "llvm/BinaryFormat/DXContainer.h"
 #include "llvm/IR/DiagnosticInfo.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicsDirectX.h"
@@ -173,8 +170,7 @@ class DXILPostOptimizationValidationLegacy : public ModulePass {
         getAnalysis<DXILResourceWrapperPass>().getResourceMap();
     DXILResourceBindingInfo &DRBI =
         getAnalysis<DXILResourceBindingWrapperPass>().getBindingInfo();
-
-    RootSignatureBindingInfo& RSBI =
+    RootSignatureBindingInfo &RSBI =
         getAnalysis<RootSignatureAnalysisWrapper>().getRSInfo();
     dxil::ModuleMetadataInfo &MMI =
         getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();
diff --git a/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.h b/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.h
index cb5e624514272..151843daf068d 100644
--- a/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.h
+++ b/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.h
@@ -14,6 +14,9 @@
 #ifndef LLVM_LIB_TARGET_DIRECTX_DXILPOSTOPTIMIZATIONVALIDATION_H
 #define LLVM_LIB_TARGET_DIRECTX_DXILPOSTOPTIMIZATIONVALIDATION_H
 
+#include "DXILRootSignature.h"
+#include "llvm/ADT/IntervalMap.h"
+#include "llvm/Analysis/DXILResource.h"
 #include "llvm/IR/PassManager.h"
 
 namespace llvm {
diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.h b/llvm/lib/Target/DirectX/DXILRootSignature.h
index ecfc577d1b97d..d0d5c7785bda3 100644
--- a/llvm/lib/Target/DirectX/DXILRootSignature.h
+++ b/llvm/lib/Target/DirectX/DXILRootSignature.h
@@ -37,30 +37,28 @@ enum class RootSignatureElementKind {
 };
 
 class RootSignatureBindingInfo {
-private:
-  SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc> FuncToRsMap;
+  private:
+    SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc> FuncToRsMap;
 
-public:
+  public:
   using iterator =
-      SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc>::iterator;
+        SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc>::iterator;
 
-  RootSignatureBindingInfo() = default;
-  RootSignatureBindingInfo(
-      SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc> Map)
-      : FuncToRsMap(Map){};
+RootSignatureBindingInfo () = default;
+  RootSignatureBindingInfo(SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc> Map) : FuncToRsMap(Map) {};
 
   iterator find(const Function *F) { return FuncToRsMap.find(F); }
 
   iterator end() { return FuncToRsMap.end(); }
 
-  std::optional<mcdxbc::RootSignatureDesc>
-  getDescForFunction(const Function *F) {
+  std::optional<mcdxbc::RootSignatureDesc> getDescForFunction(const Function *F) {
     const auto FuncRs = find(F);
     if (FuncRs == end())
       return std::nullopt;
 
     return FuncRs->second;
   }
+
 };
 
 class RootSignatureAnalysis : public AnalysisInfoMixin<RootSignatureAnalysis> {
@@ -68,7 +66,7 @@ class RootSignatureAnalysis : public AnalysisInfoMixin<RootSignatureAnalysis> {
   static AnalysisKey Key;
 
 public:
-  RootSignatureAnalysis() = default;
+RootSignatureAnalysis() = default;
 
   using Result = RootSignatureBindingInfo;
 
@@ -88,8 +86,8 @@ class RootSignatureAnalysisWrapper : public ModulePass {
   using Result = RootSignatureBindingInfo;
 
   RootSignatureAnalysisWrapper() : ModulePass(ID) {}
-
-  RootSignatureBindingInfo &getRSInfo() { return *FuncToRsMap; }
+  
+  RootSignatureBindingInfo& getRSInfo() {return *FuncToRsMap;}
 
   bool runOnModule(Module &M) override;
 

>From e8b14bf32e47cf8c059d2f492e57a602375ceeaa Mon Sep 17 00:00:00 2001
From: joaosaffran <joao.saffran at microsoft.com>
Date: Fri, 4 Jul 2025 02:03:26 +0000
Subject: [PATCH 5/8] Report resource registers not bound in the root signature

---
 .../RootSignature-Validation-Fail.hlsl        |  35 ++++
 .../SemaHLSL/RootSignature-Validation.hlsl    |  11 +-
 .../DXILPostOptimizationValidation.cpp        | 166 +++++++++++++-----
 .../DirectX/DXILPostOptimizationValidation.h  |  88 ++++++++++
 llvm/lib/Target/DirectX/DXILRootSignature.h   |  24 +--
 .../RootSignature-DescriptorTable.ll          |   4 +-
 6 files changed, 271 insertions(+), 57 deletions(-)
 create mode 100644 clang/test/SemaHLSL/RootSignature-Validation-Fail.hlsl

diff --git a/clang/test/SemaHLSL/RootSignature-Validation-Fail.hlsl b/clang/test/SemaHLSL/RootSignature-Validation-Fail.hlsl
new file mode 100644
index 0000000000000..b590ed67e7085
--- /dev/null
+++ b/clang/test/SemaHLSL/RootSignature-Validation-Fail.hlsl
@@ -0,0 +1,35 @@
+// RUN: not %clang_dxc -T cs_6_6 -E CSMain %s 2>&1 | FileCheck %s
+
+// CHECK: error: register cbuffer (space=665, register=3) is not defined in Root Signature
+// CHECK: error: register srv (space=0, register=0) is not defined in Root Signature
+// CHECK: error: register uav (space=0, register=4294967295) is not defined in Root Signature
+
+
+#define ROOT_SIGNATURE \
+    "CBV(b3, space=666, visibility=SHADER_VISIBILITY_ALL), " \
+    "DescriptorTable(SRV(t0, space=0, numDescriptors=1), visibility=SHADER_VISIBILITY_VERTEX), " \
+    "DescriptorTable(Sampler(s0, numDescriptors=2), visibility=SHADER_VISIBILITY_ALL), " \
+    "DescriptorTable(UAV(u0, numDescriptors=unbounded), visibility=SHADER_VISIBILITY_ALL)"
+
+cbuffer CB : register(b3, space665) {
+  float a;
+}
+
+StructuredBuffer<int> In : register(t0, space0);
+RWStructuredBuffer<int> Out : register(u0);
+
+RWBuffer<float> UAV : register(u4294967295);
+
+RWBuffer<float> UAV1 : register(u2), UAV2 : register(u4);
+
+RWBuffer<float> UAV3 : register(space0);
+
+
+
+// Compute Shader for UAV testing
+[numthreads(8, 8, 1)]
+[RootSignature(ROOT_SIGNATURE)]
+void CSMain(uint id : SV_GroupID)
+{
+    Out[0] = a + id + In[0] + UAV[0] + UAV1[0] + UAV3[0];
+}
diff --git a/clang/test/SemaHLSL/RootSignature-Validation.hlsl b/clang/test/SemaHLSL/RootSignature-Validation.hlsl
index 62ba704b95c7d..5a7f5baf00619 100644
--- a/clang/test/SemaHLSL/RootSignature-Validation.hlsl
+++ b/clang/test/SemaHLSL/RootSignature-Validation.hlsl
@@ -1,19 +1,22 @@
+// RUN: %clang_dxc -T cs_6_6 -E CSMain %s 2>&1 
+
+// expected-no-diagnostics
+
 
 #define ROOT_SIGNATURE \
-    "RootFlags(ALLOW_INPUT_ASSEMBLER_INPUT_LAYOUT), " \
     "CBV(b3, space=1, visibility=SHADER_VISIBILITY_ALL), " \
     "DescriptorTable(SRV(t0, space=0, numDescriptors=1), visibility=SHADER_VISIBILITY_ALL), " \
-    "DescriptorTable(Sampler(s0, numDescriptors=2), visibility=SHADER_VISIBILITY_ALL), " \
+    "DescriptorTable(Sampler(s0, numDescriptors=2), visibility=SHADER_VISIBILITY_VERTEX), " \
     "DescriptorTable(UAV(u0, numDescriptors=unbounded), visibility=SHADER_VISIBILITY_ALL)"
 
-cbuffer CB : register(b3, space2) {
+cbuffer CB : register(b3, space1) {
   float a;
 }
 
 StructuredBuffer<int> In : register(t0, space0);
 RWStructuredBuffer<int> Out : register(u0);
 
-RWBuffer<float> UAV : register(u3);
+RWBuffer<float> UAV : register(u4294967294);
 
 RWBuffer<float> UAV1 : register(u2), UAV2 : register(u4);
 
diff --git a/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp b/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp
index 4c29b56304391..23bb5d1a7f651 100644
--- a/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp
+++ b/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp
@@ -84,9 +84,57 @@ static void reportOverlappingBinding(Module &M, DXILResourceMap &DRM) {
     }
   }
 }
-  uint64_t combine_uint32_to_uint64(uint32_t high, uint32_t low) {
-      return (static_cast<uint64_t>(high) << 32) | low;
+
+static void reportRegNotBound(Module &M, Twine Type,
+                              ResourceInfo::ResourceBinding Binding) {
+  SmallString<128> Message;
+  raw_svector_ostream OS(Message);
+  OS << "register " << Type << " (space=" << Binding.Space
+     << ", register=" << Binding.LowerBound << ")"
+     << " is not defined in Root Signature";
+  M.getContext().diagnose(DiagnosticInfoGeneric(Message));
+}
+
+static dxbc::ShaderVisibility
+tripleToVisibility(llvm::Triple::EnvironmentType ET) {
+  assert((ET == Triple::Pixel || ET == Triple::Vertex ||
+          ET == Triple::Geometry || ET == Triple::Hull ||
+          ET == Triple::Domain || ET == Triple::Mesh ||
+          ET == Triple::Compute) &&
+         "Invalid Triple to shader stage conversion");
+
+  switch (ET) {
+  case Triple::Pixel:
+    return dxbc::ShaderVisibility::Pixel;
+  case Triple::Vertex:
+    return dxbc::ShaderVisibility::Vertex;
+  case Triple::Geometry:
+    return dxbc::ShaderVisibility::Geometry;
+  case Triple::Hull:
+    return dxbc::ShaderVisibility::Hull;
+  case Triple::Domain:
+    return dxbc::ShaderVisibility::Domain;
+  case Triple::Mesh:
+    return dxbc::ShaderVisibility::Mesh;
+  case Triple::Compute:
+    return dxbc::ShaderVisibility::All;
+  default:
+    llvm_unreachable("Invalid triple to shader stage conversion");
   }
+}
+
+std::optional<mcdxbc::RootSignatureDesc>
+getRootSignature(RootSignatureBindingInfo &RSBI,
+                 dxil::ModuleMetadataInfo &MMI) {
+  if (MMI.EntryPropertyVec.size() == 0)
+    return std::nullopt;
+  std::optional<mcdxbc::RootSignatureDesc> RootSigDesc =
+      RSBI.getDescForFunction(MMI.EntryPropertyVec[0].Entry);
+  if (!RootSigDesc)
+    return std::nullopt;
+  return RootSigDesc;
+}
+
 static void reportErrors(Module &M, DXILResourceMap &DRM,
                          DXILResourceBindingInfo &DRBI,
                          RootSignatureBindingInfo &RSBI,
@@ -99,57 +147,95 @@ static void reportErrors(Module &M, DXILResourceMap &DRM,
 
   assert(!DRBI.hasImplicitBinding() && "implicit bindings should be handled in "
                                        "DXILResourceImplicitBinding pass");
-  // Assuming this is used to validate only the root signature assigned to the
-  // entry function.
-  //Start test stuff
-  if(MMI.EntryPropertyVec.size() == 0)
-    return;
 
-  std::optional<mcdxbc::RootSignatureDesc> RootSigDesc =
-      RSBI.getDescForFunction(MMI.EntryPropertyVec[0].Entry);
-  if (!RootSigDesc)
-    return;
+  if (auto RSD = getRootSignature(RSBI, MMI)) {
+
+    RootSignatureBindingValidation Validation;
+    Validation.addRsBindingInfo(*RSD, tripleToVisibility(MMI.ShaderProfile));
+
+    for (const auto &CBuf : DRM.cbuffers()) {
+      ResourceInfo::ResourceBinding Binding = CBuf.getBinding();
+      if (!Validation.checkCregBinding(Binding))
+        reportRegNotBound(M, "cbuffer", Binding);
+    }
+
+    for (const auto &CBuf : DRM.srvs()) {
+      ResourceInfo::ResourceBinding Binding = CBuf.getBinding();
+      if (!Validation.checkTRegBinding(Binding))
+        reportRegNotBound(M, "srv", Binding);
+    }
 
-  using MapT = llvm::IntervalMap<uint64_t, llvm::dxil::ResourceInfo::ResourceBinding, sizeof(llvm::dxil::ResourceInfo::ResourceBinding), llvm::IntervalMapInfo<uint64_t>>;
-  MapT::Allocator Allocator;
-  MapT BindingsMap(Allocator);
-  auto RSD = *RootSigDesc;
-   for (size_t I = 0; I < RSD.ParametersContainer.size(); I++) {
+    for (const auto &CBuf : DRM.uavs()) {
+      ResourceInfo::ResourceBinding Binding = CBuf.getBinding();
+      if (!Validation.checkURegBinding(Binding))
+        reportRegNotBound(M, "uav", Binding);
+    }
+  }
+}
+} // namespace
+
+void RootSignatureBindingValidation::addRsBindingInfo(
+    mcdxbc::RootSignatureDesc &RSD, dxbc::ShaderVisibility Visibility) {
+  for (size_t I = 0; I < RSD.ParametersContainer.size(); I++) {
     const auto &[Type, Loc] =
-        RootSigDesc->ParametersContainer.getTypeAndLocForParameter(I);
+        RSD.ParametersContainer.getTypeAndLocForParameter(I);
+
+    const auto &Header = RSD.ParametersContainer.getHeader(I);
     switch (Type) {
-    case llvm::to_underlying(dxbc::RootParameterType::CBV):{
+    case llvm::to_underlying(dxbc::RootParameterType::SRV):
+    case llvm::to_underlying(dxbc::RootParameterType::UAV):
+    case llvm::to_underlying(dxbc::RootParameterType::CBV): {
       dxbc::RTS0::v2::RootDescriptor Desc =
-          RootSigDesc->ParametersContainer.getRootDescriptor(Loc);
+          RSD.ParametersContainer.getRootDescriptor(Loc);
 
-      llvm::dxil::ResourceInfo::ResourceBinding Binding;
-      Binding.LowerBound = Desc.ShaderRegister;
-      Binding.Space = Desc.RegisterSpace;
-      Binding.Size = 1;
+      if (Header.ShaderVisibility ==
+              llvm::to_underlying(dxbc::ShaderVisibility::All) ||
+          Header.ShaderVisibility == llvm::to_underlying(Visibility))
+        addRange(Desc, Type);
+      break;
+    }
+    case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
+      const mcdxbc::DescriptorTable &Table =
+          RSD.ParametersContainer.getDescriptorTable(Loc);
 
-      BindingsMap.insert(combine_uint32_to_uint64(Binding.Space, Binding.LowerBound), combine_uint32_to_uint64(Binding.Space, Binding.LowerBound + Binding.Size -1), Binding);
+      for (const dxbc::RTS0::v2::DescriptorRange &Range : Table.Ranges) {
+        if (Range.RangeType ==
+            llvm::to_underlying(dxbc::DescriptorRangeType::Sampler))
+          continue;
+
+        if (Header.ShaderVisibility ==
+                llvm::to_underlying(dxbc::ShaderVisibility::All) ||
+            Header.ShaderVisibility == llvm::to_underlying(Visibility))
+          addRange(Range);
+      }
       break;
     }
-    // case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable):{
-    //   mcdxbc::DescriptorTable Table =
-    //       RootSigDesc->ParametersContainer.getDescriptorTable(Loc);
-    //   for (const dxbc::RTS0::v2::DescriptorRange &Range : Table){
-    //     Range.
-    //   }
-      
-    //   break;
-    // }
     }
-
   }
+}
 
-  for(const auto &CBuf : DRM.cbuffers()) {
-    auto Binding = CBuf.getBinding();
-    if(!BindingsMap.overlaps(combine_uint32_to_uint64(Binding.Space, Binding.LowerBound), combine_uint32_to_uint64(Binding.Space, Binding.LowerBound + Binding.Size -1)))
-      auto X = 1;
-  }
+bool RootSignatureBindingValidation::checkCregBinding(
+    ResourceInfo::ResourceBinding Binding) {
+  return CRegBindingsMap.overlaps(
+      combineUint32ToUint64(Binding.Space, Binding.LowerBound),
+      combineUint32ToUint64(Binding.Space,
+                            Binding.LowerBound + Binding.Size - 1));
+}
+
+bool RootSignatureBindingValidation::checkTRegBinding(
+    ResourceInfo::ResourceBinding Binding) {
+  return TRegBindingsMap.overlaps(
+      combineUint32ToUint64(Binding.Space, Binding.LowerBound),
+      combineUint32ToUint64(Binding.Space, Binding.LowerBound + Binding.Size));
+}
+
+bool RootSignatureBindingValidation::checkURegBinding(
+    ResourceInfo::ResourceBinding Binding) {
+  return URegBindingsMap.overlaps(
+      combineUint32ToUint64(Binding.Space, Binding.LowerBound),
+      combineUint32ToUint64(Binding.Space,
+                            Binding.LowerBound + Binding.Size - 1));
 }
-} // namespace
 
 PreservedAnalyses
 DXILPostOptimizationValidation::run(Module &M, ModuleAnalysisManager &MAM) {
diff --git a/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.h b/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.h
index 151843daf068d..58113bf9f93c7 100644
--- a/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.h
+++ b/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.h
@@ -21,6 +21,94 @@
 
 namespace llvm {
 
+inline uint64_t combineUint32ToUint64(uint32_t High, uint32_t Low) {
+  return (static_cast<uint64_t>(High) << 32) | Low; // High:Low packed key.
+}
+
+class RootSignatureBindingValidation {
+  using MapT =
+      llvm::IntervalMap<uint64_t, dxil::ResourceInfo::ResourceBinding,
+                        sizeof(llvm::dxil::ResourceInfo::ResourceBinding),
+                        llvm::IntervalMapInfo<uint64_t>>;
+
+private:
+  MapT::Allocator Allocator;
+  MapT CRegBindingsMap;
+  MapT TRegBindingsMap;
+  MapT URegBindingsMap;
+
+  void addRange(const dxbc::RTS0::v2::RootDescriptor &Desc, uint32_t Type) {
+    assert((Type == llvm::to_underlying(dxbc::RootParameterType::CBV) ||
+            Type == llvm::to_underlying(dxbc::RootParameterType::SRV) ||
+            Type == llvm::to_underlying(dxbc::RootParameterType::UAV)) &&
+           "Invalid Type");
+
+    llvm::dxil::ResourceInfo::ResourceBinding Binding;
+    Binding.LowerBound = Desc.ShaderRegister;
+    Binding.Space = Desc.RegisterSpace;
+    Binding.Size = 1;
+
+    uint64_t LowRange =
+        combineUint32ToUint64(Binding.Space, Binding.LowerBound);
+    uint64_t HighRange = combineUint32ToUint64(
+        Binding.Space, Binding.LowerBound + Binding.Size - 1);
+
+    switch (Type) {
+
+    case llvm::to_underlying(dxbc::RootParameterType::CBV):
+      CRegBindingsMap.insert(LowRange, HighRange, Binding);
+      return;
+    case llvm::to_underlying(dxbc::RootParameterType::SRV):
+      TRegBindingsMap.insert(LowRange, HighRange, Binding);
+      return;
+    case llvm::to_underlying(dxbc::RootParameterType::UAV):
+      URegBindingsMap.insert(LowRange, HighRange, Binding);
+      return;
+    }
+    llvm_unreachable("Invalid Type in add Range Method");
+  }
+
+  void addRange(const dxbc::RTS0::v2::DescriptorRange &Range) {
+
+    llvm::dxil::ResourceInfo::ResourceBinding Binding;
+    Binding.LowerBound = Range.BaseShaderRegister;
+    Binding.Space = Range.RegisterSpace;
+    Binding.Size = Range.NumDescriptors;
+
+    uint64_t LowRange =
+        combineUint32ToUint64(Binding.Space, Binding.LowerBound);
+    uint64_t HighRange = combineUint32ToUint64(
+        Binding.Space, Binding.LowerBound + Binding.Size - 1);
+
+    switch (Range.RangeType) {
+    case llvm::to_underlying(dxbc::DescriptorRangeType::CBV):
+      CRegBindingsMap.insert(LowRange, HighRange, Binding);
+      return;
+    case llvm::to_underlying(dxbc::DescriptorRangeType::SRV):
+      TRegBindingsMap.insert(LowRange, HighRange, Binding);
+      return;
+    case llvm::to_underlying(dxbc::DescriptorRangeType::UAV):
+      URegBindingsMap.insert(LowRange, HighRange, Binding);
+      return;
+    }
+    llvm_unreachable("Invalid Type in add Range Method");
+  }
+
+public:
+  RootSignatureBindingValidation()
+      : Allocator(), CRegBindingsMap(Allocator), TRegBindingsMap(Allocator),
+        URegBindingsMap(Allocator) {}
+
+  void addRsBindingInfo(mcdxbc::RootSignatureDesc &RSD,
+                        dxbc::ShaderVisibility Visibility);
+
+  bool checkCregBinding(dxil::ResourceInfo::ResourceBinding Binding);
+
+  bool checkTRegBinding(dxil::ResourceInfo::ResourceBinding Binding);
+
+  bool checkURegBinding(dxil::ResourceInfo::ResourceBinding Binding);
+};
+
 class DXILPostOptimizationValidation
     : public PassInfoMixin<DXILPostOptimizationValidation> {
 public:
diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.h b/llvm/lib/Target/DirectX/DXILRootSignature.h
index d0d5c7785bda3..ecfc577d1b97d 100644
--- a/llvm/lib/Target/DirectX/DXILRootSignature.h
+++ b/llvm/lib/Target/DirectX/DXILRootSignature.h
@@ -37,28 +37,30 @@ enum class RootSignatureElementKind {
 };
 
 class RootSignatureBindingInfo {
-  private:
-    SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc> FuncToRsMap;
+private:
+  SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc> FuncToRsMap;
 
-  public:
+public:
   using iterator =
-        SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc>::iterator;
+      SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc>::iterator;
 
-RootSignatureBindingInfo () = default;
-  RootSignatureBindingInfo(SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc> Map) : FuncToRsMap(Map) {};
+  RootSignatureBindingInfo() = default;
+  RootSignatureBindingInfo(
+      SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc> Map)
+      : FuncToRsMap(Map){};
 
   iterator find(const Function *F) { return FuncToRsMap.find(F); }
 
   iterator end() { return FuncToRsMap.end(); }
 
-  std::optional<mcdxbc::RootSignatureDesc> getDescForFunction(const Function *F) {
+  std::optional<mcdxbc::RootSignatureDesc>
+  getDescForFunction(const Function *F) {
     const auto FuncRs = find(F);
     if (FuncRs == end())
       return std::nullopt;
 
     return FuncRs->second;
   }
-
 };
 
 class RootSignatureAnalysis : public AnalysisInfoMixin<RootSignatureAnalysis> {
@@ -66,7 +68,7 @@ class RootSignatureAnalysis : public AnalysisInfoMixin<RootSignatureAnalysis> {
   static AnalysisKey Key;
 
 public:
-RootSignatureAnalysis() = default;
+  RootSignatureAnalysis() = default;
 
   using Result = RootSignatureBindingInfo;
 
@@ -86,8 +88,8 @@ class RootSignatureAnalysisWrapper : public ModulePass {
   using Result = RootSignatureBindingInfo;
 
   RootSignatureAnalysisWrapper() : ModulePass(ID) {}
-  
-  RootSignatureBindingInfo& getRSInfo() {return *FuncToRsMap;}
+
+  RootSignatureBindingInfo &getRSInfo() { return *FuncToRsMap; }
 
   bool runOnModule(Module &M) override;
 
diff --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable.ll b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable.ll
index b516d66180247..8e9b4b43b11a6 100644
--- a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable.ll
+++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable.ll
@@ -16,7 +16,7 @@ attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
 !2 = !{ ptr @main, !3, i32 2 } ; function, root signature
 !3 = !{ !5 } ; list of root signature elements
 !5 = !{ !"DescriptorTable", i32 0, !6, !7 }
-!6 = !{ !"SRV", i32 0, i32 1, i32 0, i32 -1, i32 4 }
+!6 = !{ !"SRV", i32 1, i32 1, i32 0, i32 -1, i32 4 }
 !7 = !{ !"UAV", i32 5, i32 1, i32 10, i32 5, i32 2 }
 
 ; DXC:  - Name:            RTS0
@@ -35,7 +35,7 @@ attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
 ; DXC-NEXT:            RangesOffset:    44
 ; DXC-NEXT:            Ranges:
 ; DXC-NEXT:              - RangeType:       0
-; DXC-NEXT:                NumDescriptors:  0
+; DXC-NEXT:                NumDescriptors:  1
 ; DXC-NEXT:                BaseShaderRegister: 1
 ; DXC-NEXT:                RegisterSpace:   0
 ; DXC-NEXT:                OffsetInDescriptorsFromTableStart: 4294967295

>From 8f40e83ab0db147e90070f15708d0a0f4e1a9d1f Mon Sep 17 00:00:00 2001
From: joaosaffran <joao.saffran at microsoft.com>
Date: Fri, 4 Jul 2025 19:24:25 +0000
Subject: [PATCH 6/8] Validate sampler bindings, fix SRV range check, update tests

---
 .../DXILPostOptimizationValidation.cpp        | 45 +++++-----------
 .../DirectX/DXILPostOptimizationValidation.h  | 54 ++++++++++++++-----
 llvm/lib/Target/DirectX/DXILRootSignature.cpp |  5 +-
 ...criptorTable-AllValidFlagCombinationsV1.ll |  4 +-
 llvm/test/CodeGen/DirectX/llc-pipeline.ll     |  1 +
 5 files changed, 59 insertions(+), 50 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp b/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp
index 23bb5d1a7f651..a52a04323514c 100644
--- a/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp
+++ b/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp
@@ -153,23 +153,29 @@ static void reportErrors(Module &M, DXILResourceMap &DRM,
     RootSignatureBindingValidation Validation;
     Validation.addRsBindingInfo(*RSD, tripleToVisibility(MMI.ShaderProfile));
 
-    for (const auto &CBuf : DRM.cbuffers()) {
+    for (const ResourceInfo &CBuf : DRM.cbuffers()) {
       ResourceInfo::ResourceBinding Binding = CBuf.getBinding();
-      if (!Validation.checkCregBinding(Binding))
+      if (!Validation.checkCRegBinding(Binding))
         reportRegNotBound(M, "cbuffer", Binding);
     }
 
-    for (const auto &CBuf : DRM.srvs()) {
-      ResourceInfo::ResourceBinding Binding = CBuf.getBinding();
+    for (const ResourceInfo &SRV : DRM.srvs()) {
+      ResourceInfo::ResourceBinding Binding = SRV.getBinding();
       if (!Validation.checkTRegBinding(Binding))
         reportRegNotBound(M, "srv", Binding);
     }
 
-    for (const auto &CBuf : DRM.uavs()) {
-      ResourceInfo::ResourceBinding Binding = CBuf.getBinding();
+    for (const ResourceInfo &UAV : DRM.uavs()) {
+      ResourceInfo::ResourceBinding Binding = UAV.getBinding();
       if (!Validation.checkURegBinding(Binding))
         reportRegNotBound(M, "uav", Binding);
     }
+
+    for (const ResourceInfo &Sampler : DRM.samplers()) {
+      ResourceInfo::ResourceBinding Binding = Sampler.getBinding();
+      if (!Validation.checkSamplerBinding(Binding))
+        reportRegNotBound(M, "sampler", Binding);
+    }
   }
 }
 } // namespace
@@ -199,10 +205,6 @@ void RootSignatureBindingValidation::addRsBindingInfo(
           RSD.ParametersContainer.getDescriptorTable(Loc);
 
       for (const dxbc::RTS0::v2::DescriptorRange &Range : Table.Ranges) {
-        if (Range.RangeType ==
-            llvm::to_underlying(dxbc::DescriptorRangeType::Sampler))
-          continue;
-
         if (Header.ShaderVisibility ==
                 llvm::to_underlying(dxbc::ShaderVisibility::All) ||
             Header.ShaderVisibility == llvm::to_underlying(Visibility))
@@ -214,29 +216,6 @@ void RootSignatureBindingValidation::addRsBindingInfo(
   }
 }
 
-bool RootSignatureBindingValidation::checkCregBinding(
-    ResourceInfo::ResourceBinding Binding) {
-  return CRegBindingsMap.overlaps(
-      combineUint32ToUint64(Binding.Space, Binding.LowerBound),
-      combineUint32ToUint64(Binding.Space,
-                            Binding.LowerBound + Binding.Size - 1));
-}
-
-bool RootSignatureBindingValidation::checkTRegBinding(
-    ResourceInfo::ResourceBinding Binding) {
-  return TRegBindingsMap.overlaps(
-      combineUint32ToUint64(Binding.Space, Binding.LowerBound),
-      combineUint32ToUint64(Binding.Space, Binding.LowerBound + Binding.Size));
-}
-
-bool RootSignatureBindingValidation::checkURegBinding(
-    ResourceInfo::ResourceBinding Binding) {
-  return URegBindingsMap.overlaps(
-      combineUint32ToUint64(Binding.Space, Binding.LowerBound),
-      combineUint32ToUint64(Binding.Space,
-                            Binding.LowerBound + Binding.Size - 1));
-}
-
 PreservedAnalyses
 DXILPostOptimizationValidation::run(Module &M, ModuleAnalysisManager &MAM) {
   DXILResourceMap &DRM = MAM.getResult<DXILResourceAnalysis>(M);
diff --git a/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.h b/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.h
index 58113bf9f93c7..0fa0285425d7e 100644
--- a/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.h
+++ b/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.h
@@ -36,12 +36,13 @@ class RootSignatureBindingValidation {
   MapT CRegBindingsMap;
   MapT TRegBindingsMap;
   MapT URegBindingsMap;
+  MapT SamplersBindingsMap;
 
   void addRange(const dxbc::RTS0::v2::RootDescriptor &Desc, uint32_t Type) {
     assert((Type == llvm::to_underlying(dxbc::RootParameterType::CBV) ||
             Type == llvm::to_underlying(dxbc::RootParameterType::SRV) ||
             Type == llvm::to_underlying(dxbc::RootParameterType::UAV)) &&
-           "Invalid Type");
+           "Invalid Type in add Range Method");
 
     llvm::dxil::ResourceInfo::ResourceBinding Binding;
     Binding.LowerBound = Desc.ShaderRegister;
@@ -53,19 +54,20 @@ class RootSignatureBindingValidation {
     uint64_t HighRange = combineUint32ToUint64(
         Binding.Space, Binding.LowerBound + Binding.Size - 1);
 
+    assert(LowRange <= HighRange && "Invalid range configuration");
+
     switch (Type) {
 
     case llvm::to_underlying(dxbc::RootParameterType::CBV):
       CRegBindingsMap.insert(LowRange, HighRange, Binding);
-      return;
+      break;
     case llvm::to_underlying(dxbc::RootParameterType::SRV):
       TRegBindingsMap.insert(LowRange, HighRange, Binding);
-      return;
+      break;
     case llvm::to_underlying(dxbc::RootParameterType::UAV):
       URegBindingsMap.insert(LowRange, HighRange, Binding);
-      return;
+      break;
     }
-    llvm_unreachable("Invalid Type in add Range Method");
   }
 
   void addRange(const dxbc::RTS0::v2::DescriptorRange &Range) {
@@ -80,33 +82,59 @@ class RootSignatureBindingValidation {
     uint64_t HighRange = combineUint32ToUint64(
         Binding.Space, Binding.LowerBound + Binding.Size - 1);
 
+    assert(LowRange <= HighRange && "Invalid range configuration");
+
     switch (Range.RangeType) {
     case llvm::to_underlying(dxbc::DescriptorRangeType::CBV):
       CRegBindingsMap.insert(LowRange, HighRange, Binding);
-      return;
+      break;
     case llvm::to_underlying(dxbc::DescriptorRangeType::SRV):
       TRegBindingsMap.insert(LowRange, HighRange, Binding);
-      return;
+      break;
     case llvm::to_underlying(dxbc::DescriptorRangeType::UAV):
       URegBindingsMap.insert(LowRange, HighRange, Binding);
-      return;
+      break;
+    case llvm::to_underlying(dxbc::DescriptorRangeType::Sampler):
+      SamplersBindingsMap.insert(LowRange, HighRange, Binding);
+      break;
     }
-    llvm_unreachable("Invalid Type in add Range Method");
   }
 
 public:
   RootSignatureBindingValidation()
       : Allocator(), CRegBindingsMap(Allocator), TRegBindingsMap(Allocator),
-        URegBindingsMap(Allocator) {}
+        URegBindingsMap(Allocator), SamplersBindingsMap(Allocator) {}
 
   void addRsBindingInfo(mcdxbc::RootSignatureDesc &RSD,
                         dxbc::ShaderVisibility Visibility);
 
-  bool checkCregBinding(dxil::ResourceInfo::ResourceBinding Binding);
+  bool checkCRegBinding(dxil::ResourceInfo::ResourceBinding Binding) {
+    // Inclusive end in 64 bits, clamped to the space: avoids uint32_t wrap.
+    uint64_t Low = combineUint32ToUint64(Binding.Space, Binding.LowerBound);
+    uint64_t High = Low + std::max<uint64_t>(Binding.Size, 1) - 1;
+    return CRegBindingsMap.overlaps(Low, std::min(High, Low | UINT32_MAX));
+  }
 
-  bool checkTRegBinding(dxil::ResourceInfo::ResourceBinding Binding);
+  bool checkTRegBinding(dxil::ResourceInfo::ResourceBinding Binding) {
+    // Inclusive end in 64 bits, clamped to the space: avoids uint32_t wrap.
+    uint64_t Low = combineUint32ToUint64(Binding.Space, Binding.LowerBound);
+    uint64_t High = Low + std::max<uint64_t>(Binding.Size, 1) - 1;
+    return TRegBindingsMap.overlaps(Low, std::min(High, Low | UINT32_MAX));
+  }
 
-  bool checkURegBinding(dxil::ResourceInfo::ResourceBinding Binding);
+  bool checkURegBinding(dxil::ResourceInfo::ResourceBinding Binding) {
+    // Inclusive end in 64 bits, clamped to the space: avoids uint32_t wrap.
+    uint64_t Low = combineUint32ToUint64(Binding.Space, Binding.LowerBound);
+    uint64_t High = Low + std::max<uint64_t>(Binding.Size, 1) - 1;
+    return URegBindingsMap.overlaps(Low, std::min(High, Low | UINT32_MAX));
+  }
+
+  bool checkSamplerBinding(dxil::ResourceInfo::ResourceBinding Binding) {
+    // Inclusive end in 64 bits, clamped to the space: avoids uint32_t wrap.
+    uint64_t Low = combineUint32ToUint64(Binding.Space, Binding.LowerBound);
+    uint64_t High = Low + std::max<uint64_t>(Binding.Size, 1) - 1;
+    return SamplersBindingsMap.overlaps(Low, std::min(High, Low | UINT32_MAX));
+  }
 };
 
 class DXILPostOptimizationValidation
diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.cpp b/llvm/lib/Target/DirectX/DXILRootSignature.cpp
index 4094df160ef6f..2a68a4c324a09 100644
--- a/llvm/lib/Target/DirectX/DXILRootSignature.cpp
+++ b/llvm/lib/Target/DirectX/DXILRootSignature.cpp
@@ -635,8 +635,9 @@ PreservedAnalyses RootSignatureAnalysisPrinter::run(Module &M,
 
 //===----------------------------------------------------------------------===//
 bool RootSignatureAnalysisWrapper::runOnModule(Module &M) {
-  FuncToRsMap = std::make_unique<RootSignatureBindingInfo>(
-      RootSignatureBindingInfo(analyzeModule(M)));
+  if (!FuncToRsMap) // NOTE(review): computed once; confirm it can't go stale.
+    FuncToRsMap = std::make_unique<RootSignatureBindingInfo>(
+        RootSignatureBindingInfo(analyzeModule(M)));
   return false;
 }
 
diff --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-AllValidFlagCombinationsV1.ll b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-AllValidFlagCombinationsV1.ll
index 9d89dbdd9107b..053721de1eb1f 100644
--- a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-AllValidFlagCombinationsV1.ll
+++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-AllValidFlagCombinationsV1.ll
@@ -13,7 +13,7 @@ attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
 !2 = !{ ptr @main, !3, i32 1 } ; function, root signature
 !3 = !{ !5 } ; list of root signature elements
 !5 = !{ !"DescriptorTable", i32 0, !6, !7 }
-!6 = !{ !"Sampler", i32 0, i32 1, i32 0, i32 -1, i32 1 }
+!6 = !{ !"Sampler", i32 1, i32 1, i32 0, i32 -1, i32 1 }
 !7 = !{ !"UAV", i32 5, i32 1, i32 10, i32 5, i32 3 }
 
 
@@ -33,7 +33,7 @@ attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
 ; DXC-NEXT:             RangesOffset:    44
 ; DXC-NEXT:             Ranges:
 ; DXC-NEXT:               - RangeType:       3
-; DXC-NEXT:                 NumDescriptors:  0
+; DXC-NEXT:                 NumDescriptors:  1
 ; DXC-NEXT:                 BaseShaderRegister: 1
 ; DXC-NEXT:                 RegisterSpace:   0
 ; DXC-NEXT:                 OffsetInDescriptorsFromTableStart: 4294967295
diff --git a/llvm/test/CodeGen/DirectX/llc-pipeline.ll b/llvm/test/CodeGen/DirectX/llc-pipeline.ll
index 2b29fd30a7a56..8d75249dc6ecb 100644
--- a/llvm/test/CodeGen/DirectX/llc-pipeline.ll
+++ b/llvm/test/CodeGen/DirectX/llc-pipeline.ll
@@ -31,6 +31,7 @@
 ; CHECK-NEXT:   DXIL Module Metadata analysis
 ; CHECK-NEXT:   DXIL Shader Flag Analysis
 ; CHECK-NEXT:   DXIL Translate Metadata
+; CHECK-NEXT:   DXIL Root Signature Analysis
 ; CHECK-NEXT:   DXIL Post Optimization Validation
 ; CHECK-NEXT:   DXIL Op Lowering
 ; CHECK-NEXT:   DXIL Prepare Module

>From 28350b2dfe2a896b2199260953c1d061550badba Mon Sep 17 00:00:00 2001
From: joaosaffran <joao.saffran at microsoft.com>
Date: Sat, 5 Jul 2025 00:35:07 +0000
Subject: [PATCH 7/8] Remove duplicated RSA/RS declarations in addRootSignature

---
 llvm/lib/Target/DirectX/DXContainerGlobals.cpp | 2 --
 1 file changed, 2 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
index 5c763c24a210a..6c8ae8eaaea77 100644
--- a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
+++ b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
@@ -160,11 +160,9 @@ void DXContainerGlobals::addRootSignature(Module &M,
 
   assert(MMI.EntryPropertyVec.size() == 1);
 
-  auto &RSA = getAnalysis<RootSignatureAnalysisWrapper>().getRSInfo();
   auto &RSA = getAnalysis<RootSignatureAnalysisWrapper>().getRSInfo();
   const Function *EntryFunction = MMI.EntryPropertyVec[0].Entry;
   const auto &RS = RSA.getDescForFunction(EntryFunction);
-  const auto &RS = RSA.getDescForFunction(EntryFunction);
 
   if (!RS)
     return;

>From 4fd2e0bfdda5f87a3204982aa01d8b158945eadf Mon Sep 17 00:00:00 2001
From: joaosaffran <joao.saffran at microsoft.com>
Date: Sat, 5 Jul 2025 00:37:12 +0000
Subject: [PATCH 8/8] Sync DXILRootSignature.h with the parent PR

---
 llvm/lib/Target/DirectX/DXILRootSignature.h | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.h b/llvm/lib/Target/DirectX/DXILRootSignature.h
index ecfc577d1b97d..3832182277050 100644
--- a/llvm/lib/Target/DirectX/DXILRootSignature.h
+++ b/llvm/lib/Target/DirectX/DXILRootSignature.h
@@ -47,7 +47,7 @@ class RootSignatureBindingInfo {
   RootSignatureBindingInfo() = default;
   RootSignatureBindingInfo(
       SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc> Map)
-      : FuncToRsMap(Map){};
+      : FuncToRsMap(std::move(Map)) {} // Map is a by-value sink; move it in.
 
   iterator find(const Function *F) { return FuncToRsMap.find(F); }
 
@@ -72,7 +72,10 @@ class RootSignatureAnalysis : public AnalysisInfoMixin<RootSignatureAnalysis> {
 
   using Result = RootSignatureBindingInfo;
 
-  RootSignatureBindingInfo run(Module &M, ModuleAnalysisManager &AM);
+  Result run(Module &M, ModuleAnalysisManager &AM);
+
+private:
+  std::unique_ptr<RootSignatureBindingInfo> AnalysisResult;
 };
 
 /// Wrapper pass for the legacy pass manager.



More information about the llvm-branch-commits mailing list