[llvm-branch-commits] [DirectX] Add resource handling to the DXIL pretty printer (PR #104448)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Aug 15 07:52:17 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-analysis

Author: Justin Bogner (bogner)

<details>
<summary>Changes</summary>

Handle resources represented as target extension types when printing the resource bindings table in textual IR.


---
Full diff: https://github.com/llvm/llvm-project/pull/104448.diff


5 Files Affected:

- (modified) llvm/include/llvm/Analysis/DXILResource.h (+15-1) 
- (modified) llvm/lib/Target/DirectX/DXILOpLowering.cpp (+2-2) 
- (modified) llvm/lib/Target/DirectX/DXILPrettyPrinter.cpp (+222-5) 
- (modified) llvm/test/CodeGen/DirectX/CreateHandle.ll (+9) 
- (modified) llvm/test/CodeGen/DirectX/CreateHandleFromBinding.ll (+9) 


``````````diff
diff --git a/llvm/include/llvm/Analysis/DXILResource.h b/llvm/include/llvm/Analysis/DXILResource.h
index 2ed508b28a908..faee9f5dac1b4 100644
--- a/llvm/include/llvm/Analysis/DXILResource.h
+++ b/llvm/include/llvm/Analysis/DXILResource.h
@@ -142,12 +142,19 @@ class ResourceInfo {
     Binding.LowerBound = LowerBound;
     Binding.Size = Size;
   }
+  /// Returns the binding info: record ID, space, lower bound and size.
+  const ResourceBinding &getBinding() const { return Binding; }
   void setUAV(bool GloballyCoherent, bool HasCounter, bool IsROV) {
     assert(isUAV() && "Not a UAV");
     UAVFlags.GloballyCoherent = GloballyCoherent;
     UAVFlags.HasCounter = HasCounter;
     UAVFlags.IsROV = IsROV;
   }
+  /// Returns the flags stored by setUAV(). Asserts that this is a UAV.
+  const UAVInfo &getUAV() const {
+    assert(isUAV() && "Not a UAV");
+    return UAVFlags;
+  }
   void setCBuffer(uint32_t Size) {
     assert(isCBuffer() && "Not a CBuffer");
     CBufferSize = Size;
@@ -163,6 +170,11 @@ class ResourceInfo {
     Typed.ElementTy = ElementTy;
     Typed.ElementCount = ElementCount;
   }
+  /// Returns the element type and count. Asserts that this is typed.
+  const TypedInfo &getTyped() const {
+    assert(isTyped() && "Not typed");
+    return Typed;
+  }
   void setFeedback(dxil::SamplerFeedbackType Type) {
     assert(isFeedback() && "Not Feedback");
     Feedback.Type = Type;
@@ -171,8 +183,15 @@ class ResourceInfo {
     assert(isMultiSample() && "Not MultiSampled");
     MultiSample.Count = Count;
   }
+  /// Returns the sample-count info. Asserts that this is multisampled.
+  const MSInfo &getMultiSample() const {
+    assert(isMultiSample() && "Not MultiSampled");
+    return MultiSample;
+  }
 
+  StringRef getName() const { return Name; }
   dxil::ResourceClass getResourceClass() const { return RC; }
+  dxil::ResourceKind getResourceKind() const { return Kind; }
 
   bool operator==(const ResourceInfo &RHS) const;
 
@@ -222,7 +241,6 @@ class ResourceInfo {
 
   MDTuple *getAsMetadata(LLVMContext &Ctx) const;
 
-  ResourceBinding getBinding() const { return Binding; }
   std::pair<uint32_t, uint32_t> getAnnotateProps() const;
 
   void print(raw_ostream &OS) const;
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index f34302cc95065..e7c36ead1cc34 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -175,7 +175,7 @@ class OpLowerer {
       IRB.SetInsertPoint(CI);
 
       dxil::ResourceInfo &RI = DRM[CI];
-      dxil::ResourceInfo::ResourceBinding Binding = RI.getBinding();
+      const dxil::ResourceInfo::ResourceBinding &Binding = RI.getBinding();
 
       std::array<Value *, 4> Args{
           ConstantInt::get(Int8Ty, llvm::to_underlying(RI.getResourceClass())),
@@ -201,7 +201,7 @@ class OpLowerer {
       IRB.SetInsertPoint(CI);
 
       dxil::ResourceInfo &RI = DRM[CI];
-      dxil::ResourceInfo::ResourceBinding Binding = RI.getBinding();
+      const dxil::ResourceInfo::ResourceBinding &Binding = RI.getBinding();
       std::pair<uint32_t, uint32_t> Props = RI.getAnnotateProps();
 
       Constant *ResBind = OpBuilder.getResBind(
diff --git a/llvm/lib/Target/DirectX/DXILPrettyPrinter.cpp b/llvm/lib/Target/DirectX/DXILPrettyPrinter.cpp
index c57631cc4c8b6..76a40dbfc5845 100644
--- a/llvm/lib/Target/DirectX/DXILPrettyPrinter.cpp
+++ b/llvm/lib/Target/DirectX/DXILPrettyPrinter.cpp
@@ -10,23 +10,263 @@
 #include "DXILResourceAnalysis.h"
 #include "DirectX.h"
 #include "llvm/ADT/StringRef.h"
+#include "llvm/Analysis/DXILResource.h"
 #include "llvm/IR/PassManager.h"
+#include "llvm/InitializePasses.h"
 #include "llvm/Pass.h"
+#include "llvm/Support/FormatAdapters.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/raw_ostream.h"
+#include <limits>
 
 using namespace llvm;
 
-static void prettyPrintResources(raw_ostream &OS,
+/// Returns the resource class name printed in the "Type" column.
+static constexpr StringRef getRCName(dxil::ResourceClass RC) {
+  switch (RC) {
+  case dxil::ResourceClass::SRV:
+    return "SRV";
+  case dxil::ResourceClass::UAV:
+    return "UAV";
+  case dxil::ResourceClass::CBuffer:
+    return "cbuffer";
+  case dxil::ResourceClass::Sampler:
+    return "sampler";
+  }
+  llvm_unreachable("covered switch");
+}
+
+/// Returns the HLSL register prefix for a resource class. It is printed as-is
+/// in the "HLSL Bind" column and upper-cased in the "ID" column.
+static constexpr StringRef getRCPrefix(dxil::ResourceClass RC) {
+  switch (RC) {
+  case dxil::ResourceClass::SRV:
+    return "t";
+  case dxil::ResourceClass::UAV:
+    return "u";
+  case dxil::ResourceClass::CBuffer:
+    return "cb";
+  case dxil::ResourceClass::Sampler:
+    return "s";
+  }
+  // Every enumerator is handled above; this keeps GCC's -Wreturn-type quiet.
+  llvm_unreachable("covered switch");
+}
+
+/// Returns the text printed in the "Format" column: the element type for
+/// typed resources, "struct" for structured resources, "NA" for cbuffers and
+/// samplers, and "byte" for anything else.
+static constexpr StringRef getFormatName(const dxil::ResourceInfo &RI) {
+  if (RI.isTyped()) {
+    switch (RI.getTyped().ElementTy) {
+    case dxil::ElementType::I1:
+      return "i1";
+    case dxil::ElementType::I16:
+      return "i16";
+    case dxil::ElementType::U16:
+      return "u16";
+    case dxil::ElementType::I32:
+      return "i32";
+    case dxil::ElementType::U32:
+      return "u32";
+    case dxil::ElementType::I64:
+      return "i64";
+    case dxil::ElementType::U64:
+      return "u64";
+    case dxil::ElementType::F16:
+      return "f16";
+    case dxil::ElementType::F32:
+      return "f32";
+    case dxil::ElementType::F64:
+      return "f64";
+    case dxil::ElementType::SNormF16:
+      return "snorm_f16";
+    case dxil::ElementType::UNormF16:
+      return "unorm_f16";
+    case dxil::ElementType::SNormF32:
+      return "snorm_f32";
+    case dxil::ElementType::UNormF32:
+      return "unorm_f32";
+    case dxil::ElementType::SNormF64:
+      return "snorm_f64";
+    case dxil::ElementType::UNormF64:
+      return "unorm_f64";
+    case dxil::ElementType::PackedS8x32:
+      return "p32i8";
+    case dxil::ElementType::PackedU8x32:
+      return "p32u8";
+    case dxil::ElementType::Invalid:
+      llvm_unreachable("Invalid ElementType");
+    }
+    llvm_unreachable("Unhandled ElementType");
+  }
+  if (RI.isStruct())
+    return "struct";
+  if (RI.isCBuffer() || RI.isSampler())
+    return "NA";
+  return "byte";
+}
+
+/// Returns the "Dim" column text for texture resource kinds. Non-texture kinds
+/// must be handled by the caller; they reach llvm_unreachable here.
+static constexpr StringRef getTextureDimName(dxil::ResourceKind RK) {
+  switch (RK) {
+  case dxil::ResourceKind::Texture1D:
+    return "1d";
+  case dxil::ResourceKind::Texture2D:
+    return "2d";
+  case dxil::ResourceKind::Texture3D:
+    return "3d";
+  case dxil::ResourceKind::TextureCube:
+    return "cube";
+  case dxil::ResourceKind::Texture1DArray:
+    return "1darray";
+  case dxil::ResourceKind::Texture2DArray:
+    return "2darray";
+  case dxil::ResourceKind::TextureCubeArray:
+    return "cubearray";
+  case dxil::ResourceKind::TBuffer:
+    return "tbuffer";
+  case dxil::ResourceKind::FeedbackTexture2D:
+    return "fbtex2d";
+  case dxil::ResourceKind::FeedbackTexture2DArray:
+    return "fbtex2darray";
+  case dxil::ResourceKind::Texture2DMS:
+    return "2dMS";
+  case dxil::ResourceKind::Texture2DMSArray:
+    return "2darrayMS";
+  case dxil::ResourceKind::Invalid:
+  case dxil::ResourceKind::NumEntries:
+  case dxil::ResourceKind::CBuffer:
+  case dxil::ResourceKind::RawBuffer:
+  case dxil::ResourceKind::Sampler:
+  case dxil::ResourceKind::StructuredBuffer:
+  case dxil::ResourceKind::TypedBuffer:
+  case dxil::ResourceKind::RTAccelerationStructure:
+    llvm_unreachable("Invalid ResourceKind for texture");
+  }
+  llvm_unreachable("Unhandled ResourceKind");
+}
+
+namespace {
+/// Format adapter for the "Dim" column.
+struct FormatResourceDimension
+    : public llvm::FormatAdapter<const dxil::ResourceInfo &> {
+  explicit FormatResourceDimension(const dxil::ResourceInfo &RI)
+      : llvm::FormatAdapter<const dxil::ResourceInfo &>(RI) {}
+
+  void format(llvm::raw_ostream &OS, StringRef Style) override {
+    dxil::ResourceKind RK = Item.getResourceKind();
+    switch (RK) {
+    default:
+      // Texture kinds. Multisampled textures append their sample count.
+      OS << getTextureDimName(RK);
+      if (Item.isMultiSample())
+        OS << Item.getMultiSample().Count;
+      break;
+    case dxil::ResourceKind::RawBuffer:
+    case dxil::ResourceKind::StructuredBuffer:
+      if (!Item.isUAV())
+        OS << "r/o";
+      else if (Item.getUAV().HasCounter)
+        OS << "r/w+cnt";
+      else
+        OS << "r/w";
+      break;
+    case dxil::ResourceKind::TypedBuffer:
+      OS << "buf";
+      break;
+    case dxil::ResourceKind::CBuffer:
+    case dxil::ResourceKind::Sampler:
+      // No dimension applies. These must not reach the default case, where
+      // getTextureDimName treats them as unreachable.
+      OS << "NA";
+      break;
+    case dxil::ResourceKind::RTAccelerationStructure:
+      // TODO: dxc would print "ras" here. Can/should this happen?
+      llvm_unreachable("RTAccelerationStructure printing is not implemented");
+    }
+  }
+};
+
+/// Format adapter for the "ID" column: the upper-cased register prefix
+/// followed by the record ID, e.g. "U0" or "CB0".
+struct FormatBindingID
+    : public llvm::FormatAdapter<const dxil::ResourceInfo &> {
+  explicit FormatBindingID(const dxil::ResourceInfo &RI)
+      : llvm::FormatAdapter<const dxil::ResourceInfo &>(RI) {}
+
+  void format(llvm::raw_ostream &OS, StringRef Style) override {
+    OS << getRCPrefix(Item.getResourceClass()).upper()
+       << Item.getBinding().RecordID;
+  }
+};
+
+/// Format adapter for the "HLSL Bind" column, e.g. "u5,space3". The space is
+/// omitted when it is zero.
+struct FormatBindingLocation
+    : public llvm::FormatAdapter<const dxil::ResourceInfo &> {
+  explicit FormatBindingLocation(const dxil::ResourceInfo &RI)
+      : llvm::FormatAdapter<const dxil::ResourceInfo &>(RI) {}
+
+  void format(llvm::raw_ostream &OS, StringRef Style) override {
+    const auto &Binding = Item.getBinding();
+    OS << getRCPrefix(Item.getResourceClass()) << Binding.LowerBound;
+    if (Binding.Space)
+      OS << ",space" << Binding.Space;
+  }
+};
+
+/// Format adapter for the "Count" column. A size of UINT32_MAX is printed as
+/// "unbounded".
+struct FormatBindingSize
+    : public llvm::FormatAdapter<const dxil::ResourceInfo &> {
+  explicit FormatBindingSize(const dxil::ResourceInfo &RI)
+      : llvm::FormatAdapter<const dxil::ResourceInfo &>(RI) {}
+
+  void format(llvm::raw_ostream &OS, StringRef Style) override {
+    uint32_t Size = Item.getBinding().Size;
+    if (Size == std::numeric_limits<uint32_t>::max())
+      OS << "unbounded";
+    else
+      OS << Size;
+  }
+};
+
+} // namespace
+
+/// Prints the resource binding table: one row per resource in DRM (the new
+/// representation), followed by the resources still described by the old
+/// metadata representation in MDResources.
+static void prettyPrintResources(raw_ostream &OS, const DXILResourceMap &DRM,
                                  const dxil::Resources &MDResources) {
   // Column widths are arbitrary but match the widths DXC uses.
   OS << ";\n; Resource Bindings:\n;\n";
-  OS << formatv("; {0,-30} {1,10} {2,7} {3,11} {4,7} {5,14} {6,16}\n",
+  OS << formatv("; {0,-30} {1,10} {2,7} {3,11} {4,7} {5,14} {6,9}\n",
                 "Name", "Type", "Format", "Dim", "ID", "HLSL Bind", "Count");
   OS << formatv(
-      "; {0,-+30} {1,-+10} {2,-+7} {3,-+11} {4,-+7} {5,-+14} {6,-+16}\n", "",
+      "; {0,-+30} {1,-+10} {2,-+7} {3,-+11} {4,-+7} {5,-+14} {6,-+9}\n", "",
       "", "", "", "", "", "");
 
+  // TODO: Do we want to sort these by binding or something like that?
+  for (const auto &[_, RI] : DRM) {
+    dxil::ResourceClass RC = RI.getResourceClass();
+    assert((RC != dxil::ResourceClass::CBuffer || !MDResources.hasCBuffers()) &&
+           "Old and new cbuffer representations can't coexist");
+    assert((RC != dxil::ResourceClass::UAV || !MDResources.hasUAVs()) &&
+           "Old and new UAV representations can't coexist");
+
+    StringRef Name(RI.getName());
+    StringRef Type(getRCName(RC));
+    StringRef Format(getFormatName(RI));
+    FormatResourceDimension Dim(RI);
+    FormatBindingID ID(RI);
+    FormatBindingLocation Bind(RI);
+    FormatBindingSize Count(RI);
+    OS << formatv("; {0,-30} {1,10} {2,7} {3,11} {4,7} {5,14} {6,9}\n",
+                  Name, Type, Format, Dim, ID, Bind, Count);
+  }
+
   if (MDResources.hasCBuffers())
     MDResources.printCBuffers(OS);
   if (MDResources.hasUAVs())
@@ -37,8 +277,9 @@ static void prettyPrintResources(raw_ostream &OS,
 
 PreservedAnalyses DXILPrettyPrinterPass::run(Module &M,
                                              ModuleAnalysisManager &MAM) {
+  const DXILResourceMap &DRM = MAM.getResult<DXILResourceAnalysis>(M);
   const dxil::Resources &MDResources = MAM.getResult<DXILResourceMDAnalysis>(M);
-  prettyPrintResources(OS, MDResources);
+  prettyPrintResources(OS, DRM, MDResources);
   return PreservedAnalyses::all();
 }
 
@@ -63,6 +304,7 @@ class DXILPrettyPrinterLegacy : public llvm::ModulePass {
   bool runOnModule(Module &M) override;
   void getAnalysisUsage(AnalysisUsage &AU) const override {
     AU.setPreservesAll();
+    AU.addRequired<DXILResourceWrapperPass>();
     AU.addRequired<DXILResourceMDWrapper>();
   }
 };
@@ -71,13 +313,16 @@ class DXILPrettyPrinterLegacy : public llvm::ModulePass {
 char DXILPrettyPrinterLegacy::ID = 0;
 INITIALIZE_PASS_BEGIN(DXILPrettyPrinterLegacy, "dxil-pretty-printer",
                       "DXIL Metadata Pretty Printer", true, true)
+INITIALIZE_PASS_DEPENDENCY(DXILResourceWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(DXILResourceMDWrapper)
 INITIALIZE_PASS_END(DXILPrettyPrinterLegacy, "dxil-pretty-printer",
                     "DXIL Metadata Pretty Printer", true, true)
 
 bool DXILPrettyPrinterLegacy::runOnModule(Module &M) {
+  const DXILResourceMap &DRM =
+      getAnalysis<DXILResourceWrapperPass>().getResourceMap();
   dxil::Resources &Res = getAnalysis<DXILResourceMDWrapper>().getDXILResource();
-  prettyPrintResources(OS, Res);
+  prettyPrintResources(OS, DRM, Res);
   return false;
 }
 
diff --git a/llvm/test/CodeGen/DirectX/CreateHandle.ll b/llvm/test/CodeGen/DirectX/CreateHandle.ll
index f0d1c8da5a425..cbb7359642c27 100644
--- a/llvm/test/CodeGen/DirectX/CreateHandle.ll
+++ b/llvm/test/CodeGen/DirectX/CreateHandle.ll
@@ -1,4 +1,13 @@
 ; RUN: opt -S -passes=dxil-op-lower,dxil-translate-metadata %s | FileCheck %s
+; RUN: opt -S -passes=dxil-pretty-printer %s 2>&1 >/dev/null | FileCheck --check-prefix=CHECK-PRETTY %s
+
+; CHECK-PRETTY:       Type  Format         Dim      ID      HLSL Bind     Count
+; CHECK-PRETTY-NEXT: ---------- ------- ----------- ------- -------------- ---------
+; CHECK-PRETTY-NEXT:        UAV     f32         buf      U0      u5,space3         1
+; CHECK-PRETTY-NEXT:        UAV     i32         buf      U1      u7,space2         1
+; CHECK-PRETTY-NEXT:        SRV     u32         buf      T0      t3,space5        24
+; CHECK-PRETTY-NEXT:        SRV  struct         r/o      T1      t2,space4         1
+; CHECK-PRETTY-NEXT:        SRV    byte         r/o      T2      t8,space1         1
 
 target triple = "dxil-pc-shadermodel6.0-compute"
 
diff --git a/llvm/test/CodeGen/DirectX/CreateHandleFromBinding.ll b/llvm/test/CodeGen/DirectX/CreateHandleFromBinding.ll
index 345459a60c5ab..aea251a3612ef 100644
--- a/llvm/test/CodeGen/DirectX/CreateHandleFromBinding.ll
+++ b/llvm/test/CodeGen/DirectX/CreateHandleFromBinding.ll
@@ -1,4 +1,13 @@
 ; RUN: opt -S -passes=dxil-op-lower,dxil-translate-metadata %s | FileCheck %s
+; RUN: opt -S -passes=dxil-pretty-printer %s 2>&1 >/dev/null | FileCheck --check-prefix=CHECK-PRETTY %s
+
+; CHECK-PRETTY:       Type  Format         Dim      ID      HLSL Bind     Count
+; CHECK-PRETTY-NEXT: ---------- ------- ----------- ------- -------------- ---------
+; CHECK-PRETTY-NEXT:        UAV     f32         buf      U0      u5,space3         1
+; CHECK-PRETTY-NEXT:        UAV     i32         buf      U1      u7,space2         1
+; CHECK-PRETTY-NEXT:        SRV     u32         buf      T0      t3,space5        24
+; CHECK-PRETTY-NEXT:        SRV  struct         r/o      T1      t2,space4         1
+; CHECK-PRETTY-NEXT:        SRV    byte         r/o      T2      t8,space1         1
 
 target triple = "dxil-pc-shadermodel6.6-compute"
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/104448


More information about the llvm-branch-commits mailing list