[llvm] [DirectX] generate resource table for PSV part (PR #106607)

via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 29 11:46:10 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-directx

Author: Xiang Li (python3kgae)

<details>
<summary>Changes</summary>

Use DXILResourceWrapperPass to build the resource table for the PSV (pipeline state validation) part.

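Since DXILResourceWrapperPass operates on LLVM intrinsics rather than DXIL operations, mark it as preserved (via `addPreserved<DXILResourceWrapperPass>()`) in the passes that run before DXContainerGlobals, so that the resource map remains available when DXContainerGlobals requests it.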

Fixes #<!-- -->103275

---
Full diff: https://github.com/llvm/llvm-project/pull/106607.diff


7 Files Affected:

- (modified) llvm/lib/Target/DirectX/DXContainerGlobals.cpp (+56) 
- (modified) llvm/lib/Target/DirectX/DXILFinalizeLinkage.cpp (+6) 
- (modified) llvm/lib/Target/DirectX/DXILFinalizeLinkage.h (+1) 
- (modified) llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp (+6) 
- (modified) llvm/lib/Target/DirectX/DXILIntrinsicExpansion.h (+1) 
- (modified) llvm/lib/Target/DirectX/DXILPrepare.cpp (+2) 
- (added) llvm/test/CodeGen/DirectX/ContainerData/PSVResources.ll (+87) 


``````````diff
diff --git a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
index d47b9c7a25b8fe..e91af93ec1ed3e 100644
--- a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
+++ b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
@@ -15,6 +15,7 @@
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringRef.h"
+#include "llvm/Analysis/DXILResource.h"
 #include "llvm/BinaryFormat/DXContainer.h"
 #include "llvm/CodeGen/Passes.h"
 #include "llvm/IR/Constants.h"
@@ -39,6 +40,7 @@ class DXContainerGlobals : public llvm::ModulePass {
   GlobalVariable *buildSignature(Module &M, Signature &Sig, StringRef Name,
                                  StringRef SectionName);
   void addSignature(Module &M, SmallVector<GlobalValue *> &Globals);
+  void addResourcesForPSV(Module &M, PSVRuntimeInfo &PSV);
   void addPipelineStateValidationInfo(Module &M,
                                       SmallVector<GlobalValue *> &Globals);
 
@@ -57,6 +59,7 @@ class DXContainerGlobals : public llvm::ModulePass {
   void getAnalysisUsage(AnalysisUsage &AU) const override {
     AU.setPreservesAll();
     AU.addRequired<ShaderFlagsAnalysisWrapper>();
+    AU.addRequired<DXILResourceWrapperPass>();
   }
 };
 
@@ -138,6 +141,56 @@ void DXContainerGlobals::addSignature(Module &M,
   Globals.emplace_back(buildSignature(M, OutputSig, "dx.osg1", "OSG1"));
 }
 
+void DXContainerGlobals::addResourcesForPSV(Module &M, PSVRuntimeInfo &PSV) {
+  const DXILResourceMap &ResMap =
+      getAnalysis<DXILResourceWrapperPass>().getResourceMap();
+
+  for (const dxil::ResourceInfo &ResInfo : ResMap) {
+    const dxil::ResourceInfo::ResourceBinding &Binding = ResInfo.getBinding();
+    dxbc::PSV::v2::ResourceBindInfo BindInfo;
+    BindInfo.LowerBound = Binding.LowerBound;
+    BindInfo.UpperBound = Binding.LowerBound + Binding.Size - 1;
+    BindInfo.Space = Binding.Space;
+
+    dxbc::PSV::ResourceType ResType = dxbc::PSV::ResourceType::Invalid;
+    bool IsUAV = ResInfo.getResourceClass() == dxil::ResourceClass::UAV;
+    switch (ResInfo.getResourceKind()) {
+    case dxil::ResourceKind::Sampler:
+      ResType = dxbc::PSV::ResourceType::Sampler;
+      break;
+    case dxil::ResourceKind::CBuffer:
+      ResType = dxbc::PSV::ResourceType::CBV;
+      break;
+    case dxil::ResourceKind::StructuredBuffer:
+      ResType = IsUAV ? dxbc::PSV::ResourceType::UAVStructured
+                      : dxbc::PSV::ResourceType::SRVStructured;
+      if (IsUAV && ResInfo.getUAV().HasCounter)
+        ResType = dxbc::PSV::ResourceType::UAVStructuredWithCounter;
+      break;
+    case dxil::ResourceKind::RTAccelerationStructure:
+      ResType = dxbc::PSV::ResourceType::SRVRaw;
+      break;
+    case dxil::ResourceKind::RawBuffer:
+      ResType = IsUAV ? dxbc::PSV::ResourceType::UAVRaw
+                      : dxbc::PSV::ResourceType::SRVRaw;
+      break;
+    default:
+      ResType = IsUAV ? dxbc::PSV::ResourceType::UAVTyped
+                      : dxbc::PSV::ResourceType::SRVTyped;
+      break;
+    }
+    BindInfo.Type = ResType;
+
+    BindInfo.Kind =
+        static_cast<dxbc::PSV::ResourceKind>(ResInfo.getResourceKind());
+    // TODO: Add support for dxbc::PSV::ResourceFlag::UsedByAtomic64, tracking
+    // with https://github.com/llvm/llvm-project/issues/104392
+    BindInfo.Flags = 0u;
+
+    PSV.Resources.emplace_back(BindInfo);
+  }
+}
+
 void DXContainerGlobals::addPipelineStateValidationInfo(
     Module &M, SmallVector<GlobalValue *> &Globals) {
   SmallString<256> Data;
@@ -149,6 +202,8 @@ void DXContainerGlobals::addPipelineStateValidationInfo(
   PSV.BaseData.ShaderStage =
       static_cast<uint8_t>(TT.getEnvironment() - Triple::Pixel);
 
+  addResourcesForPSV(M, PSV);
+
   // Hardcoded values here to unblock loading the shader into D3D.
   //
   // TODO: Lots more stuff to do here!
@@ -170,6 +225,7 @@ char DXContainerGlobals::ID = 0;
 INITIALIZE_PASS_BEGIN(DXContainerGlobals, "dxil-globals",
                       "DXContainer Global Emitter", false, true)
 INITIALIZE_PASS_DEPENDENCY(ShaderFlagsAnalysisWrapper)
+INITIALIZE_PASS_DEPENDENCY(DXILResourceWrapperPass)
 INITIALIZE_PASS_END(DXContainerGlobals, "dxil-globals",
                     "DXContainer Global Emitter", false, true)
 
diff --git a/llvm/lib/Target/DirectX/DXILFinalizeLinkage.cpp b/llvm/lib/Target/DirectX/DXILFinalizeLinkage.cpp
index c02eb768cdf49b..2143e6840f46ed 100644
--- a/llvm/lib/Target/DirectX/DXILFinalizeLinkage.cpp
+++ b/llvm/lib/Target/DirectX/DXILFinalizeLinkage.cpp
@@ -8,6 +8,7 @@
 
 #include "DXILFinalizeLinkage.h"
 #include "DirectX.h"
+#include "llvm/Analysis/DXILResource.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/GlobalValue.h"
 #include "llvm/IR/Metadata.h"
@@ -48,6 +49,11 @@ bool DXILFinalizeLinkageLegacy::runOnModule(Module &M) {
   return finalizeLinkage(M);
 }
 
+void DXILFinalizeLinkageLegacy::getAnalysisUsage(
+    AnalysisUsage &AU) const {
+  AU.addPreserved<DXILResourceWrapperPass>();
+}
+
 char DXILFinalizeLinkageLegacy::ID = 0;
 
 INITIALIZE_PASS_BEGIN(DXILFinalizeLinkageLegacy, DEBUG_TYPE,
diff --git a/llvm/lib/Target/DirectX/DXILFinalizeLinkage.h b/llvm/lib/Target/DirectX/DXILFinalizeLinkage.h
index aab1bc3f7a28e2..62d3a8a27cfced 100644
--- a/llvm/lib/Target/DirectX/DXILFinalizeLinkage.h
+++ b/llvm/lib/Target/DirectX/DXILFinalizeLinkage.h
@@ -32,6 +32,7 @@ class DXILFinalizeLinkageLegacy : public ModulePass {
   DXILFinalizeLinkageLegacy() : ModulePass(ID) {}
   bool runOnModule(Module &M) override;
 
+  void getAnalysisUsage(AnalysisUsage &AU) const override;
   static char ID; // Pass identification.
 };
 } // namespace llvm
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index 2daa4f825c3b25..d4030e484279a9 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -14,6 +14,7 @@
 #include "DirectX.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Analysis/DXILResource.h"
 #include "llvm/CodeGen/Passes.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Instruction.h"
@@ -440,6 +441,11 @@ bool DXILIntrinsicExpansionLegacy::runOnModule(Module &M) {
   return expansionIntrinsics(M);
 }
 
+void DXILIntrinsicExpansionLegacy::getAnalysisUsage(
+    AnalysisUsage &AU) const {
+  AU.addPreserved<DXILResourceWrapperPass>();
+}
+
 char DXILIntrinsicExpansionLegacy::ID = 0;
 
 INITIALIZE_PASS_BEGIN(DXILIntrinsicExpansionLegacy, DEBUG_TYPE,
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.h b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.h
index c86681af7a3712..c8ee4b1b934b2d 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.h
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.h
@@ -26,6 +26,7 @@ class DXILIntrinsicExpansionLegacy : public ModulePass {
   bool runOnModule(Module &M) override;
   DXILIntrinsicExpansionLegacy() : ModulePass(ID) {}
 
+  void getAnalysisUsage(AnalysisUsage &AU) const override;
   static char ID; // Pass identification.
 };
 } // namespace llvm
diff --git a/llvm/lib/Target/DirectX/DXILPrepare.cpp b/llvm/lib/Target/DirectX/DXILPrepare.cpp
index 56098864e987fb..61a4a589d79a59 100644
--- a/llvm/lib/Target/DirectX/DXILPrepare.cpp
+++ b/llvm/lib/Target/DirectX/DXILPrepare.cpp
@@ -19,6 +19,7 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringSet.h"
+#include "llvm/Analysis/DXILResource.h"
 #include "llvm/CodeGen/Passes.h"
 #include "llvm/IR/AttributeMask.h"
 #include "llvm/IR/IRBuilder.h"
@@ -247,6 +248,7 @@ class DXILPrepareModule : public ModulePass {
   void getAnalysisUsage(AnalysisUsage &AU) const override {
     AU.addPreserved<ShaderFlagsAnalysisWrapper>();
     AU.addPreserved<DXILResourceMDWrapper>();
+    AU.addPreserved<DXILResourceWrapperPass>();
   }
   static char ID; // Pass identification.
 };
diff --git a/llvm/test/CodeGen/DirectX/ContainerData/PSVResources.ll b/llvm/test/CodeGen/DirectX/ContainerData/PSVResources.ll
new file mode 100644
index 00000000000000..3932ca4076f6f7
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/ContainerData/PSVResources.ll
@@ -0,0 +1,87 @@
+; RUN: llc %s --filetype=obj -o - | obj2yaml | FileCheck %s --check-prefix=DXC
+
+; Make sure resource table is created correctly.
+; DXC: Resources:
+target triple = "dxil-unknown-shadermodel6.0-compute"
+
+define void @main() #0 {
+
+  ; ByteAddressBuffer Buf : register(t8, space1)
+; DXC:        - Type:            SRVRaw
+; DXC:          Space:           1
+; DXC:          LowerBound:      8
+; DXC:          UpperBound:      8
+; DXC:          Kind:            RawBuffer
+; DXC:          Flags:           0
+  %srv0 = call target("dx.RawBuffer", i8, 0, 0)
+      @llvm.dx.handle.fromBinding.tdx.RawBuffer_i8_0_0t(
+          i32 1, i32 8, i32 1, i32 0, i1 false)
+
+  ; struct S { float4 a; uint4 b; };
+  ; StructuredBuffer<S> Buf : register(t2, space4)
+; DXC:        - Type:            SRVStructured
+; DXC:          Space:           4
+; DXC:          LowerBound:      2
+; DXC:          UpperBound:      2
+; DXC:          Kind:            StructuredBuffer
+; DXC:          Flags:           0
+  %srv1 = call target("dx.RawBuffer", {<4 x float>, <4 x i32>}, 0, 0)
+      @llvm.dx.handle.fromBinding.tdx.RawBuffer_sl_v4f32v4i32s_0_0t(
+          i32 4, i32 2, i32 1, i32 0, i1 false)
+
+  ; Buffer<uint4> Buf[24] : register(t3, space5)
+; DXC:        - Type:            SRVTyped
+; DXC:          Space:           5
+; DXC:          LowerBound:      3
+; DXC:          UpperBound:      26
+; DXC:          Kind:            TypedBuffer
+; DXC:          Flags:           0
+  %srv2 = call target("dx.TypedBuffer", <4 x i32>, 0, 0, 0)
+      @llvm.dx.handle.fromBinding.tdx.TypedBuffer_i32_0_0t(
+          i32 5, i32 3, i32 24, i32 0, i1 false)
+
+  ; RWBuffer<int> Buf : register(u7, space2)
+; DXC:        - Type:            UAVTyped
+; DXC:          Space:           2
+; DXC:          LowerBound:      7
+; DXC:          UpperBound:      7
+; DXC:          Kind:            TypedBuffer
+; DXC:          Flags:           0
+  %uav0 = call target("dx.TypedBuffer", i32, 1, 0, 1)
+      @llvm.dx.handle.fromBinding.tdx.TypedBuffer_i32_1_0t(
+          i32 2, i32 7, i32 1, i32 0, i1 false)
+
+  ; RWBuffer<float4> Buf : register(u5, space3)
+; DXC:        - Type:            UAVTyped
+; DXC:          Space:           3
+; DXC:          LowerBound:      5
+; DXC:          UpperBound:      5
+; DXC:          Kind:            TypedBuffer
+; DXC:          Flags:           0
+  %uav1 = call target("dx.TypedBuffer", <4 x float>, 1, 0, 0)
+              @llvm.dx.handle.fromBinding.tdx.TypedBuffer_f32_1_0(
+                  i32 3, i32 5, i32 1, i32 0, i1 false)
+
+  ; RWBuffer<float4> BufferArray[10] : register(u0, space4)
+; DXC:        - Type:            UAVTyped
+; DXC:          Space:           4
+; DXC:          LowerBound:      0
+; DXC:          UpperBound:      9
+; DXC:          Kind:            TypedBuffer
+; DXC:          Flags:           0
+  ; RWBuffer<float4> Buf = BufferArray[0]
+  %uav2_1 = call target("dx.TypedBuffer", <4 x float>, 1, 0, 0)
+              @llvm.dx.handle.fromBinding.tdx.TypedBuffer_f32_1_0(
+                  i32 4, i32 0, i32 10, i32 0, i1 false)
+  ; RWBuffer<float4> Buf = BufferArray[5]
+  %uav2_2 = call target("dx.TypedBuffer", <4 x float>, 1, 0, 0)
+              @llvm.dx.handle.fromBinding.tdx.TypedBuffer_f32_1_0(
+                  i32 4, i32 0, i32 10, i32 5, i1 false)
+  ret void
+}
+
+attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
+
+!dx.valver = !{!0}
+
+!0 = !{i32 1, i32 7}

``````````

</details>


https://github.com/llvm/llvm-project/pull/106607


More information about the llvm-commits mailing list