[llvm] eb2929d - [DirectX] use DXILMetadataAnalysis to build PSVRuntimeInfo (#107101)

via llvm-commits llvm-commits at lists.llvm.org
Wed Sep 4 18:59:46 PDT 2024


Author: Xiang Li
Date: 2024-09-04T21:59:42-04:00
New Revision: eb2929d323c0c44f2037cf8a345ca6984ec228eb

URL: https://github.com/llvm/llvm-project/commit/eb2929d323c0c44f2037cf8a345ca6984ec228eb
DIFF: https://github.com/llvm/llvm-project/commit/eb2929d323c0c44f2037cf8a345ca6984ec228eb.diff

LOG: [DirectX] use DXILMetadataAnalysis to build PSVRuntimeInfo (#107101)

Replace the hardcoded compute shader values in
DXContainerGlobals::addPipelineStateValidationInfo with values taken from
DXILMetadataAnalysis. The wave size is still missing.

Mark DXILMetadataAnalysis as preserved (or required) in the earlier DirectX passes so that its results are not lost before DXContainerGlobals runs.

Fix https://github.com/llvm/wg-hlsl/issues/51

Added: 
    llvm/test/CodeGen/DirectX/ContainerData/RuntimeInfoCS.ll

Modified: 
    llvm/lib/Target/DirectX/DXContainerGlobals.cpp
    llvm/lib/Target/DirectX/DXILPrepare.cpp
    llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
index d47b9c7a25b8fe..aa7769899ff270 100644
--- a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
+++ b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
@@ -15,6 +15,7 @@
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringRef.h"
+#include "llvm/Analysis/DXILMetadataAnalysis.h"
 #include "llvm/BinaryFormat/DXContainer.h"
 #include "llvm/CodeGen/Passes.h"
 #include "llvm/IR/Constants.h"
@@ -57,6 +58,7 @@ class DXContainerGlobals : public llvm::ModulePass {
   void getAnalysisUsage(AnalysisUsage &AU) const override {
     AU.setPreservesAll();
     AU.addRequired<ShaderFlagsAnalysisWrapper>();
+    AU.addRequired<DXILMetadataAnalysisWrapperPass>();
   }
 };
 
@@ -143,23 +145,35 @@ void DXContainerGlobals::addPipelineStateValidationInfo(
   SmallString<256> Data;
   raw_svector_ostream OS(Data);
   PSVRuntimeInfo PSV;
-  Triple TT(M.getTargetTriple());
   PSV.BaseData.MinimumWaveLaneCount = 0;
   PSV.BaseData.MaximumWaveLaneCount = std::numeric_limits<uint32_t>::max();
+
+  dxil::ModuleMetadataInfo &MMI =
+      getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();
+  assert(MMI.EntryPropertyVec.size() == 1 ||
+         MMI.ShaderStage == Triple::Library);
   PSV.BaseData.ShaderStage =
-      static_cast<uint8_t>(TT.getEnvironment() - Triple::Pixel);
+      static_cast<uint8_t>(MMI.ShaderStage - Triple::Pixel);
 
   // Hardcoded values here to unblock loading the shader into D3D.
   //
   // TODO: Lots more stuff to do here!
   //
   // See issue https://github.com/llvm/llvm-project/issues/96674.
-  PSV.BaseData.NumThreadsX = 1;
-  PSV.BaseData.NumThreadsY = 1;
-  PSV.BaseData.NumThreadsZ = 1;
-  PSV.EntryName = "main";
+  switch (MMI.ShaderStage) {
+  case Triple::Compute:
+    PSV.BaseData.NumThreadsX = MMI.EntryPropertyVec[0].NumThreadsX;
+    PSV.BaseData.NumThreadsY = MMI.EntryPropertyVec[0].NumThreadsY;
+    PSV.BaseData.NumThreadsZ = MMI.EntryPropertyVec[0].NumThreadsZ;
+    break;
+  default:
+    break;
+  }
+
+  if (MMI.ShaderStage != Triple::Library)
+    PSV.EntryName = MMI.EntryPropertyVec[0].Entry->getName();
 
-  PSV.finalize(TT.getEnvironment());
+  PSV.finalize(MMI.ShaderStage);
   PSV.write(OS);
   Constant *Constant =
       ConstantDataArray::getString(M.getContext(), Data, /*AddNull*/ false);
@@ -170,6 +184,7 @@ char DXContainerGlobals::ID = 0;
 INITIALIZE_PASS_BEGIN(DXContainerGlobals, "dxil-globals",
                       "DXContainer Global Emitter", false, true)
 INITIALIZE_PASS_DEPENDENCY(ShaderFlagsAnalysisWrapper)
+INITIALIZE_PASS_DEPENDENCY(DXILMetadataAnalysisWrapperPass)
 INITIALIZE_PASS_END(DXContainerGlobals, "dxil-globals",
                     "DXContainer Global Emitter", false, true)
 

diff  --git a/llvm/lib/Target/DirectX/DXILPrepare.cpp b/llvm/lib/Target/DirectX/DXILPrepare.cpp
index 56098864e987fb..f6b7355b936255 100644
--- a/llvm/lib/Target/DirectX/DXILPrepare.cpp
+++ b/llvm/lib/Target/DirectX/DXILPrepare.cpp
@@ -19,6 +19,7 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringSet.h"
+#include "llvm/Analysis/DXILMetadataAnalysis.h"
 #include "llvm/CodeGen/Passes.h"
 #include "llvm/IR/AttributeMask.h"
 #include "llvm/IR/IRBuilder.h"
@@ -247,6 +248,7 @@ class DXILPrepareModule : public ModulePass {
   void getAnalysisUsage(AnalysisUsage &AU) const override {
     AU.addPreserved<ShaderFlagsAnalysisWrapper>();
     AU.addPreserved<DXILResourceMDWrapper>();
+    AU.addPreserved<DXILMetadataAnalysisWrapperPass>();
   }
   static char ID; // Pass identification.
 };

diff  --git a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
index 2c6d20112060df..11cd9df1d1dc42 100644
--- a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
+++ b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
@@ -13,6 +13,7 @@
 #include "DXILShaderFlags.h"
 #include "DirectX.h"
 #include "llvm/ADT/StringSet.h"
+#include "llvm/Analysis/DXILMetadataAnalysis.h"
 #include "llvm/Analysis/DXILResource.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/Metadata.h"
@@ -103,6 +104,7 @@ class DXILTranslateMetadataLegacy : public ModulePass {
     AU.addRequired<DXILResourceWrapperPass>();
     AU.addRequired<DXILResourceMDWrapper>();
     AU.addRequired<ShaderFlagsAnalysisWrapper>();
+    AU.addRequired<DXILMetadataAnalysisWrapperPass>();
   }
 
   bool runOnModule(Module &M) override {

diff  --git a/llvm/test/CodeGen/DirectX/ContainerData/RuntimeInfoCS.ll b/llvm/test/CodeGen/DirectX/ContainerData/RuntimeInfoCS.ll
new file mode 100644
index 00000000000000..595e70092bb081
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/ContainerData/RuntimeInfoCS.ll
@@ -0,0 +1,41 @@
+; RUN: opt %s -dxil-embed -dxil-globals -S -o - | FileCheck %s
+; RUN: llc %s --filetype=obj -o - | obj2yaml | FileCheck %s --check-prefix=DXC
+target triple = "dxil-unknown-shadermodel6.0-compute"
+
+; CHECK: @dx.psv0 = private constant [80 x i8] c"{{.*}}", section "PSV0", align 4
+
+define void @cs_main() #0 {
+entry:
+  ret void
+}
+
+attributes #0 = { "hlsl.numthreads"="8,8,1" "hlsl.shader"="compute" }
+
+!dx.valver = !{!0}
+
+!0 = !{i32 1, i32 7}
+
+; DXC: - Name:            PSV0
+; DXC-NEXT:   Size:            80
+; DXC-NEXT:    PSVInfo:
+; DXC-NEXT:      Version:         3
+; DXC-NEXT:      ShaderStage:     5
+; DXC-NEXT:      MinimumWaveLaneCount: 0
+; DXC-NEXT:      MaximumWaveLaneCount: 4294967295
+; DXC-NEXT:      UsesViewID:      0
+; DXC-NEXT:      SigInputVectors: 0
+; DXC-NEXT:      SigOutputVectors: [ 0, 0, 0, 0 ]
+; DXC-NEXT:      NumThreadsX:     8
+; DXC-NEXT:      NumThreadsY:     8
+; DXC-NEXT:      NumThreadsZ:     1
+; DXC-NEXT:      EntryName:       cs_main
+; DXC-NEXT:      ResourceStride:  24
+; DXC-NEXT:      Resources:       []
+; DXC-NEXT:      SigInputElements: []
+; DXC-NEXT:      SigOutputElements: []
+; DXC-NEXT:      SigPatchOrPrimElements: []
+; DXC-NEXT:      InputOutputMap:
+; DXC-NEXT:        - [  ]
+; DXC-NEXT:        - [  ]
+; DXC-NEXT:        - [  ]
+; DXC-NEXT:        - [  ]


        


More information about the llvm-commits mailing list