[llvm] [NFC][RootSignature] Use `llvm::EnumEntry` for serialization of Root Signature Elements (PR #144106)

Finn Plummer via llvm-commits llvm-commits at lists.llvm.org
Fri Jun 13 09:04:52 PDT 2025


https://github.com/inbelic created https://github.com/llvm/llvm-project/pull/144106

- Enables re-use of the `printEnum` and `printFlags` helpers across enum types by making them templates
- Allows a simple `getEnumName` function to be defined for enum-to-string conversion, eliminating the need for a string stream when constructing the `Name` SmallString

- Also, makes a small fix-up so that the descriptor table clause builds its operands the same way as the other `Build*` methods (via a local `Operands` array passed to `MDNode::get`)

>From 5f20b415602edb4718edc9198ec878db3ebb2a7f Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Fri, 6 Jun 2025 21:33:15 +0000
Subject: [PATCH 1/3] nfc: use llvm::EnumEntry to convert enums to strings

---
 .../Frontend/HLSL/HLSLRootSignatureUtils.cpp  | 172 +++++++++---------
 1 file changed, 85 insertions(+), 87 deletions(-)

diff --git a/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp b/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp
index 765a3bcbed7e2..79eee0b12b304 100644
--- a/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp
+++ b/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp
@@ -15,112 +15,48 @@
 #include "llvm/ADT/bit.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Metadata.h"
+#include "llvm/Support/ScopedPrinter.h"
 
 namespace llvm {
 namespace hlsl {
 namespace rootsig {
 
-static raw_ostream &operator<<(raw_ostream &OS, const Register &Reg) {
-  switch (Reg.ViewType) {
-  case RegisterType::BReg:
-    OS << "b";
-    break;
-  case RegisterType::TReg:
-    OS << "t";
-    break;
-  case RegisterType::UReg:
-    OS << "u";
-    break;
-  case RegisterType::SReg:
-    OS << "s";
-    break;
-  }
-  OS << Reg.Number;
-  return OS;
+template <typename T>
+static StringRef getEnumName(const T Value, ArrayRef<EnumEntry<T>> Enums) {
+  for (const auto &EnumItem : Enums)
+    if (EnumItem.Value == Value)
+      return EnumItem.Name;
+  return "";
 }
 
-static raw_ostream &operator<<(raw_ostream &OS,
-                               const ShaderVisibility &Visibility) {
-  switch (Visibility) {
-  case ShaderVisibility::All:
-    OS << "All";
-    break;
-  case ShaderVisibility::Vertex:
-    OS << "Vertex";
-    break;
-  case ShaderVisibility::Hull:
-    OS << "Hull";
-    break;
-  case ShaderVisibility::Domain:
-    OS << "Domain";
-    break;
-  case ShaderVisibility::Geometry:
-    OS << "Geometry";
-    break;
-  case ShaderVisibility::Pixel:
-    OS << "Pixel";
-    break;
-  case ShaderVisibility::Amplification:
-    OS << "Amplification";
-    break;
-  case ShaderVisibility::Mesh:
-    OS << "Mesh";
-    break;
-  }
+template <typename T>
+static raw_ostream &printEnum(raw_ostream &OS, const T Value,
+                              ArrayRef<EnumEntry<T>> Enums) {
+  OS << getEnumName(Value, Enums);
 
   return OS;
 }
 
-static raw_ostream &operator<<(raw_ostream &OS, const ClauseType &Type) {
-  switch (Type) {
-  case ClauseType::CBuffer:
-    OS << "CBV";
-    break;
-  case ClauseType::SRV:
-    OS << "SRV";
-    break;
-  case ClauseType::UAV:
-    OS << "UAV";
-    break;
-  case ClauseType::Sampler:
-    OS << "Sampler";
-    break;
-  }
-
-  return OS;
-}
-
-static raw_ostream &operator<<(raw_ostream &OS,
-                               const DescriptorRangeFlags &Flags) {
+template <typename T>
+static raw_ostream &printFlags(raw_ostream &OS, const T Value,
+                               ArrayRef<EnumEntry<T>> Flags) {
   bool FlagSet = false;
-  unsigned Remaining = llvm::to_underlying(Flags);
+  unsigned Remaining = llvm::to_underlying(Value);
   while (Remaining) {
     unsigned Bit = 1u << llvm::countr_zero(Remaining);
     if (Remaining & Bit) {
       if (FlagSet)
         OS << " | ";
 
-      switch (static_cast<DescriptorRangeFlags>(Bit)) {
-      case DescriptorRangeFlags::DescriptorsVolatile:
-        OS << "DescriptorsVolatile";
-        break;
-      case DescriptorRangeFlags::DataVolatile:
-        OS << "DataVolatile";
-        break;
-      case DescriptorRangeFlags::DataStaticWhileSetAtExecute:
-        OS << "DataStaticWhileSetAtExecute";
-        break;
-      case DescriptorRangeFlags::DataStatic:
-        OS << "DataStatic";
-        break;
-      case DescriptorRangeFlags::DescriptorsStaticKeepingBufferBoundsChecks:
-        OS << "DescriptorsStaticKeepingBufferBoundsChecks";
-        break;
-      default:
+      bool Found = false;
+      for (const auto &FlagItem : Flags)
+        if (FlagItem.Value == T(Bit)) {
+          OS << FlagItem.Name;
+          Found = true;
+          break;
+        }
+      if (!Found)
         OS << "invalid: " << Bit;
-        break;
-      }
-
       FlagSet = true;
     }
     Remaining &= ~Bit;
@@ -128,6 +64,68 @@ static raw_ostream &operator<<(raw_ostream &OS,
 
   if (!FlagSet)
     OS << "None";
+  return OS;
+}
+
+static const EnumEntry<RegisterType> RegisterNames[] = {
+    {"b", RegisterType::BReg},
+    {"t", RegisterType::TReg},
+    {"u", RegisterType::UReg},
+    {"s", RegisterType::SReg},
+};
+
+static raw_ostream &operator<<(raw_ostream &OS, const Register &Reg) {
+  printEnum(OS, Reg.ViewType, ArrayRef(RegisterNames));
+  OS << Reg.Number;
+
+  return OS;
+}
+
+static const EnumEntry<ShaderVisibility> VisibilityNames[] = {
+    {"All", ShaderVisibility::All},
+    {"Vertex", ShaderVisibility::Vertex},
+    {"Hull", ShaderVisibility::Hull},
+    {"Domain", ShaderVisibility::Domain},
+    {"Geometry", ShaderVisibility::Geometry},
+    {"Pixel", ShaderVisibility::Pixel},
+    {"Amplification", ShaderVisibility::Amplification},
+    {"Mesh", ShaderVisibility::Mesh},
+};
+
+static raw_ostream &operator<<(raw_ostream &OS,
+                               const ShaderVisibility &Visibility) {
+  printEnum(OS, Visibility, ArrayRef(VisibilityNames));
+
+  return OS;
+}
+
+static const EnumEntry<dxil::ResourceClass> ResourceClassNames[] = {
+    {"CBV", dxil::ResourceClass::CBuffer},
+    {"SRV", dxil::ResourceClass::SRV},
+    {"UAV", dxil::ResourceClass::UAV},
+    {"Sampler", dxil::ResourceClass::Sampler},
+};
+
+static raw_ostream &operator<<(raw_ostream &OS, const ClauseType &Type) {
+  printEnum(OS, dxil::ResourceClass(llvm::to_underlying(Type)),
+            ArrayRef(ResourceClassNames));
+
+  return OS;
+}
+
+static const EnumEntry<DescriptorRangeFlags> DescriptorRangeFlagNames[] = {
+    {"DescriptorsVolatile", DescriptorRangeFlags::DescriptorsVolatile},
+    {"DataVolatile", DescriptorRangeFlags::DataVolatile},
+    {"DataStaticWhileSetAtExecute",
+     DescriptorRangeFlags::DataStaticWhileSetAtExecute},
+    {"DataStatic", DescriptorRangeFlags::DataStatic},
+    {"DescriptorsStaticKeepingBufferBoundsChecks",
+     DescriptorRangeFlags::DescriptorsStaticKeepingBufferBoundsChecks},
+};
+
+static raw_ostream &operator<<(raw_ostream &OS,
+                               const DescriptorRangeFlags &Flags) {
+  printFlags(OS, Flags, ArrayRef(DescriptorRangeFlagNames));
 
   return OS;
 }

>From 02931201d5d7dd11d3e1a2ac7820781103f565b7 Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Fri, 6 Jun 2025 21:43:16 +0000
Subject: [PATCH 2/3] nfc: use getEnumName instead of operator<<

---
 .../Frontend/HLSL/HLSLRootSignatureUtils.cpp   | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp b/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp
index 79eee0b12b304..3d8f90399dfc5 100644
--- a/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp
+++ b/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp
@@ -234,12 +234,12 @@ MDNode *MetadataBuilder::BuildRootConstants(const RootConstants &Constants) {
 
 MDNode *MetadataBuilder::BuildRootDescriptor(const RootDescriptor &Descriptor) {
   IRBuilder<> Builder(Ctx);
-  llvm::SmallString<7> Name;
-  llvm::raw_svector_ostream OS(Name);
-  OS << "Root" << ClauseType(llvm::to_underlying(Descriptor.Type));
-
+  StringRef TypeName =
+      getEnumName(dxil::ResourceClass(llvm::to_underlying(Descriptor.Type)),
+                  ArrayRef(ResourceClassNames));
+  llvm::SmallString<7> Name({"Root", TypeName});
   Metadata *Operands[] = {
-      MDString::get(Ctx, OS.str()),
+      MDString::get(Ctx, Name),
       ConstantAsMetadata::get(
           Builder.getInt32(llvm::to_underlying(Descriptor.Visibility))),
       ConstantAsMetadata::get(Builder.getInt32(Descriptor.Reg.Number)),
@@ -275,12 +275,12 @@ MDNode *MetadataBuilder::BuildDescriptorTable(const DescriptorTable &Table) {
 MDNode *MetadataBuilder::BuildDescriptorTableClause(
     const DescriptorTableClause &Clause) {
   IRBuilder<> Builder(Ctx);
-  std::string Name;
-  llvm::raw_string_ostream OS(Name);
-  OS << Clause.Type;
+  StringRef Name =
+      getEnumName(dxil::ResourceClass(llvm::to_underlying(Clause.Type)),
+                  ArrayRef(ResourceClassNames));
   return MDNode::get(
       Ctx, {
-               MDString::get(Ctx, OS.str()),
+               MDString::get(Ctx, Name),
                ConstantAsMetadata::get(Builder.getInt32(Clause.NumDescriptors)),
                ConstantAsMetadata::get(Builder.getInt32(Clause.Reg.Number)),
                ConstantAsMetadata::get(Builder.getInt32(Clause.Space)),

>From d8cc1f9976aba530f882c4883aaa985d1e33f518 Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Fri, 6 Jun 2025 21:44:24 +0000
Subject: [PATCH 3/3] nfc: build clause operands in a local array to fix formatting

---
 .../Frontend/HLSL/HLSLRootSignatureUtils.cpp  | 20 +++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp b/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp
index 3d8f90399dfc5..ab5ced523996a 100644
--- a/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp
+++ b/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp
@@ -278,16 +278,16 @@ MDNode *MetadataBuilder::BuildDescriptorTableClause(
   StringRef Name =
       getEnumName(dxil::ResourceClass(llvm::to_underlying(Clause.Type)),
                   ArrayRef(ResourceClassNames));
-  return MDNode::get(
-      Ctx, {
-               MDString::get(Ctx, Name),
-               ConstantAsMetadata::get(Builder.getInt32(Clause.NumDescriptors)),
-               ConstantAsMetadata::get(Builder.getInt32(Clause.Reg.Number)),
-               ConstantAsMetadata::get(Builder.getInt32(Clause.Space)),
-               ConstantAsMetadata::get(Builder.getInt32(Clause.Offset)),
-               ConstantAsMetadata::get(
-                   Builder.getInt32(llvm::to_underlying(Clause.Flags))),
-           });
+  Metadata *Operands[] = {
+      MDString::get(Ctx, Name),
+      ConstantAsMetadata::get(Builder.getInt32(Clause.NumDescriptors)),
+      ConstantAsMetadata::get(Builder.getInt32(Clause.Reg.Number)),
+      ConstantAsMetadata::get(Builder.getInt32(Clause.Space)),
+      ConstantAsMetadata::get(Builder.getInt32(Clause.Offset)),
+      ConstantAsMetadata::get(
+          Builder.getInt32(llvm::to_underlying(Clause.Flags))),
+  };
+  return MDNode::get(Ctx, Operands);
 }
 
 MDNode *MetadataBuilder::BuildStaticSampler(const StaticSampler &Sampler) {



More information about the llvm-commits mailing list