[llvm] [NFC][RootSignature] Use `llvm::EnumEntry` for serialization of Root Signature Elements (PR #144106)

via llvm-commits llvm-commits at lists.llvm.org
Fri Jun 13 09:05:23 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-hlsl

Author: Finn Plummer (inbelic)

<details>
<summary>Changes</summary>

- Enables re-use of `printEnum` and `printFlags` methods via templates
- Allows easy definition of a `getEnumName` function for enum-to-string conversion, eliminating the need to use a string stream to construct the `Name` `SmallString`

- Also does a small fix-up of the operands for the descriptor table clause so that it is consistent with the other `Build*` methods

---
Full diff: https://github.com/llvm/llvm-project/pull/144106.diff


1 File Affected:

- (modified) llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp (+103-105) 


``````````diff
diff --git a/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp b/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp
index 765a3bcbed7e2..ab5ced523996a 100644
--- a/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp
+++ b/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp
@@ -15,112 +15,48 @@
 #include "llvm/ADT/bit.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Metadata.h"
+#include "llvm/Support/ScopedPrinter.h"
 
 namespace llvm {
 namespace hlsl {
 namespace rootsig {
 
-static raw_ostream &operator<<(raw_ostream &OS, const Register &Reg) {
-  switch (Reg.ViewType) {
-  case RegisterType::BReg:
-    OS << "b";
-    break;
-  case RegisterType::TReg:
-    OS << "t";
-    break;
-  case RegisterType::UReg:
-    OS << "u";
-    break;
-  case RegisterType::SReg:
-    OS << "s";
-    break;
-  }
-  OS << Reg.Number;
-  return OS;
+template <typename T>
+static StringRef getEnumName(const T Value, ArrayRef<EnumEntry<T>> Enums) {
+  for (const auto &EnumItem : Enums)
+    if (EnumItem.Value == Value)
+      return EnumItem.Name;
+  return "";
 }
 
-static raw_ostream &operator<<(raw_ostream &OS,
-                               const ShaderVisibility &Visibility) {
-  switch (Visibility) {
-  case ShaderVisibility::All:
-    OS << "All";
-    break;
-  case ShaderVisibility::Vertex:
-    OS << "Vertex";
-    break;
-  case ShaderVisibility::Hull:
-    OS << "Hull";
-    break;
-  case ShaderVisibility::Domain:
-    OS << "Domain";
-    break;
-  case ShaderVisibility::Geometry:
-    OS << "Geometry";
-    break;
-  case ShaderVisibility::Pixel:
-    OS << "Pixel";
-    break;
-  case ShaderVisibility::Amplification:
-    OS << "Amplification";
-    break;
-  case ShaderVisibility::Mesh:
-    OS << "Mesh";
-    break;
-  }
-
-  return OS;
-}
-
-static raw_ostream &operator<<(raw_ostream &OS, const ClauseType &Type) {
-  switch (Type) {
-  case ClauseType::CBuffer:
-    OS << "CBV";
-    break;
-  case ClauseType::SRV:
-    OS << "SRV";
-    break;
-  case ClauseType::UAV:
-    OS << "UAV";
-    break;
-  case ClauseType::Sampler:
-    OS << "Sampler";
-    break;
-  }
+template <typename T>
+static raw_ostream &printEnum(raw_ostream &OS, const T Value,
+                              ArrayRef<EnumEntry<T>> Enums) {
+  OS << getEnumName(Value, Enums);
 
   return OS;
 }
 
-static raw_ostream &operator<<(raw_ostream &OS,
-                               const DescriptorRangeFlags &Flags) {
+template <typename T>
+static raw_ostream &printFlags(raw_ostream &OS, const T Value,
+                               ArrayRef<EnumEntry<T>> Flags) {
   bool FlagSet = false;
-  unsigned Remaining = llvm::to_underlying(Flags);
+  unsigned Remaining = llvm::to_underlying(Value);
   while (Remaining) {
     unsigned Bit = 1u << llvm::countr_zero(Remaining);
     if (Remaining & Bit) {
       if (FlagSet)
         OS << " | ";
 
-      switch (static_cast<DescriptorRangeFlags>(Bit)) {
-      case DescriptorRangeFlags::DescriptorsVolatile:
-        OS << "DescriptorsVolatile";
-        break;
-      case DescriptorRangeFlags::DataVolatile:
-        OS << "DataVolatile";
-        break;
-      case DescriptorRangeFlags::DataStaticWhileSetAtExecute:
-        OS << "DataStaticWhileSetAtExecute";
-        break;
-      case DescriptorRangeFlags::DataStatic:
-        OS << "DataStatic";
-        break;
-      case DescriptorRangeFlags::DescriptorsStaticKeepingBufferBoundsChecks:
-        OS << "DescriptorsStaticKeepingBufferBoundsChecks";
-        break;
-      default:
+      bool Found = false;
+      for (const auto &FlagItem : Flags)
+        if (FlagItem.Value == T(Bit)) {
+          OS << FlagItem.Name;
+          Found = true;
+          break;
+        }
+      if (!Found)
         OS << "invalid: " << Bit;
-        break;
-      }
-
       FlagSet = true;
     }
     Remaining &= ~Bit;
@@ -128,6 +64,68 @@ static raw_ostream &operator<<(raw_ostream &OS,
 
   if (!FlagSet)
     OS << "None";
+  return OS;
+}
+
+static const EnumEntry<RegisterType> RegisterNames[] = {
+    {"b", RegisterType::BReg},
+    {"t", RegisterType::TReg},
+    {"u", RegisterType::UReg},
+    {"s", RegisterType::SReg},
+};
+
+static raw_ostream &operator<<(raw_ostream &OS, const Register &Reg) {
+  printEnum(OS, Reg.ViewType, ArrayRef(RegisterNames));
+  OS << Reg.Number;
+
+  return OS;
+}
+
+static const EnumEntry<ShaderVisibility> VisibilityNames[] = {
+    {"All", ShaderVisibility::All},
+    {"Vertex", ShaderVisibility::Vertex},
+    {"Hull", ShaderVisibility::Hull},
+    {"Domain", ShaderVisibility::Domain},
+    {"Geometry", ShaderVisibility::Geometry},
+    {"Pixel", ShaderVisibility::Pixel},
+    {"Amplification", ShaderVisibility::Amplification},
+    {"Mesh", ShaderVisibility::Mesh},
+};
+
+static raw_ostream &operator<<(raw_ostream &OS,
+                               const ShaderVisibility &Visibility) {
+  printEnum(OS, Visibility, ArrayRef(VisibilityNames));
+
+  return OS;
+}
+
+static const EnumEntry<dxil::ResourceClass> ResourceClassNames[] = {
+    {"CBV", dxil::ResourceClass::CBuffer},
+    {"SRV", dxil::ResourceClass::SRV},
+    {"UAV", dxil::ResourceClass::UAV},
+    {"Sampler", dxil::ResourceClass::Sampler},
+};
+
+static raw_ostream &operator<<(raw_ostream &OS, const ClauseType &Type) {
+  printEnum(OS, dxil::ResourceClass(llvm::to_underlying(Type)),
+            ArrayRef(ResourceClassNames));
+
+  return OS;
+}
+
+static const EnumEntry<DescriptorRangeFlags> DescriptorRangeFlagNames[] = {
+    {"DescriptorsVolatile", DescriptorRangeFlags::DescriptorsVolatile},
+    {"DataVolatile", DescriptorRangeFlags::DataVolatile},
+    {"DataStaticWhileSetAtExecute",
+     DescriptorRangeFlags::DataStaticWhileSetAtExecute},
+    {"DataStatic", DescriptorRangeFlags::DataStatic},
+    {"DescriptorsStaticKeepingBufferBoundsChecks",
+     DescriptorRangeFlags::DescriptorsStaticKeepingBufferBoundsChecks},
+};
+
+static raw_ostream &operator<<(raw_ostream &OS,
+                               const DescriptorRangeFlags &Flags) {
+  printFlags(OS, Flags, ArrayRef(DescriptorRangeFlagNames));
 
   return OS;
 }
@@ -236,12 +234,12 @@ MDNode *MetadataBuilder::BuildRootConstants(const RootConstants &Constants) {
 
 MDNode *MetadataBuilder::BuildRootDescriptor(const RootDescriptor &Descriptor) {
   IRBuilder<> Builder(Ctx);
-  llvm::SmallString<7> Name;
-  llvm::raw_svector_ostream OS(Name);
-  OS << "Root" << ClauseType(llvm::to_underlying(Descriptor.Type));
-
+  StringRef TypeName =
+      getEnumName(dxil::ResourceClass(llvm::to_underlying(Descriptor.Type)),
+                  ArrayRef(ResourceClassNames));
+  llvm::SmallString<7> Name({"Root", TypeName});
   Metadata *Operands[] = {
-      MDString::get(Ctx, OS.str()),
+      MDString::get(Ctx, Name),
       ConstantAsMetadata::get(
           Builder.getInt32(llvm::to_underlying(Descriptor.Visibility))),
       ConstantAsMetadata::get(Builder.getInt32(Descriptor.Reg.Number)),
@@ -277,19 +275,19 @@ MDNode *MetadataBuilder::BuildDescriptorTable(const DescriptorTable &Table) {
 MDNode *MetadataBuilder::BuildDescriptorTableClause(
     const DescriptorTableClause &Clause) {
   IRBuilder<> Builder(Ctx);
-  std::string Name;
-  llvm::raw_string_ostream OS(Name);
-  OS << Clause.Type;
-  return MDNode::get(
-      Ctx, {
-               MDString::get(Ctx, OS.str()),
-               ConstantAsMetadata::get(Builder.getInt32(Clause.NumDescriptors)),
-               ConstantAsMetadata::get(Builder.getInt32(Clause.Reg.Number)),
-               ConstantAsMetadata::get(Builder.getInt32(Clause.Space)),
-               ConstantAsMetadata::get(Builder.getInt32(Clause.Offset)),
-               ConstantAsMetadata::get(
-                   Builder.getInt32(llvm::to_underlying(Clause.Flags))),
-           });
+  StringRef Name =
+      getEnumName(dxil::ResourceClass(llvm::to_underlying(Clause.Type)),
+                  ArrayRef(ResourceClassNames));
+  Metadata *Operands[] = {
+      MDString::get(Ctx, Name),
+      ConstantAsMetadata::get(Builder.getInt32(Clause.NumDescriptors)),
+      ConstantAsMetadata::get(Builder.getInt32(Clause.Reg.Number)),
+      ConstantAsMetadata::get(Builder.getInt32(Clause.Space)),
+      ConstantAsMetadata::get(Builder.getInt32(Clause.Offset)),
+      ConstantAsMetadata::get(
+          Builder.getInt32(llvm::to_underlying(Clause.Flags))),
+  };
+  return MDNode::get(Ctx, Operands);
 }
 
 MDNode *MetadataBuilder::BuildStaticSampler(const StaticSampler &Sampler) {

``````````

</details>


https://github.com/llvm/llvm-project/pull/144106


More information about the llvm-commits mailing list