[llvm] 63b80dd - [NFC][RootSignature] Use `llvm::EnumEntry` for serialization of Root Signature Elements (#144106)

via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 16 11:45:23 PDT 2025


Author: Finn Plummer
Date: 2025-06-16T11:45:19-07:00
New Revision: 63b80dd01dafc92104ee43e4f0f5296d644c25ec

URL: https://github.com/llvm/llvm-project/commit/63b80dd01dafc92104ee43e4f0f5296d644c25ec
DIFF: https://github.com/llvm/llvm-project/commit/63b80dd01dafc92104ee43e4f0f5296d644c25ec.diff

LOG: [NFC][RootSignature] Use `llvm::EnumEntry` for serialization of Root Signature Elements (#144106)

It was pointed out
[here](https://github.com/llvm/llvm-project/pull/143198#discussion_r2132877388)
that we may be able to use `llvm::EnumEntry` so that we can re-use the
printing logic across enumerations.

- Enables re-use of `printEnum` and `printFlags` methods via templates
- Allows easy definition of `getEnumName` function for enum-to-string
conversion, eliminating the need to use a string stream for constructing
the Name SmallString

- Also makes a small fix-up to how the operands of the descriptor table
clause are built, to be consistent with the other `Build*` methods

For reference, these are the
[test-cases](https://github.com/llvm/llvm-project/blob/main/llvm/unittests/Frontend/HLSLRootSignatureDumpTest.cpp)
whose expected output must not change.

Added: 
    

Modified: 
    llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp b/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp
index 765a3bcbed7e2..7d744781da04f 100644
--- a/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp
+++ b/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp
@@ -15,111 +15,46 @@
 #include "llvm/ADT/bit.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Metadata.h"
+#include "llvm/Support/ScopedPrinter.h"
 
 namespace llvm {
 namespace hlsl {
 namespace rootsig {
 
-static raw_ostream &operator<<(raw_ostream &OS, const Register &Reg) {
-  switch (Reg.ViewType) {
-  case RegisterType::BReg:
-    OS << "b";
-    break;
-  case RegisterType::TReg:
-    OS << "t";
-    break;
-  case RegisterType::UReg:
-    OS << "u";
-    break;
-  case RegisterType::SReg:
-    OS << "s";
-    break;
-  }
-  OS << Reg.Number;
-  return OS;
+template <typename T>
+static std::optional<StringRef> getEnumName(const T Value,
+                                            ArrayRef<EnumEntry<T>> Enums) {
+  for (const auto &EnumItem : Enums)
+    if (EnumItem.Value == Value)
+      return EnumItem.Name;
+  return std::nullopt;
 }
 
-static raw_ostream &operator<<(raw_ostream &OS,
-                               const ShaderVisibility &Visibility) {
-  switch (Visibility) {
-  case ShaderVisibility::All:
-    OS << "All";
-    break;
-  case ShaderVisibility::Vertex:
-    OS << "Vertex";
-    break;
-  case ShaderVisibility::Hull:
-    OS << "Hull";
-    break;
-  case ShaderVisibility::Domain:
-    OS << "Domain";
-    break;
-  case ShaderVisibility::Geometry:
-    OS << "Geometry";
-    break;
-  case ShaderVisibility::Pixel:
-    OS << "Pixel";
-    break;
-  case ShaderVisibility::Amplification:
-    OS << "Amplification";
-    break;
-  case ShaderVisibility::Mesh:
-    OS << "Mesh";
-    break;
-  }
-
-  return OS;
-}
-
-static raw_ostream &operator<<(raw_ostream &OS, const ClauseType &Type) {
-  switch (Type) {
-  case ClauseType::CBuffer:
-    OS << "CBV";
-    break;
-  case ClauseType::SRV:
-    OS << "SRV";
-    break;
-  case ClauseType::UAV:
-    OS << "UAV";
-    break;
-  case ClauseType::Sampler:
-    OS << "Sampler";
-    break;
-  }
-
+template <typename T>
+static raw_ostream &printEnum(raw_ostream &OS, const T Value,
+                              ArrayRef<EnumEntry<T>> Enums) {
+  auto MaybeName = getEnumName(Value, Enums);
+  if (MaybeName)
+    OS << *MaybeName;
   return OS;
 }
 
-static raw_ostream &operator<<(raw_ostream &OS,
-                               const DescriptorRangeFlags &Flags) {
+template <typename T>
+static raw_ostream &printFlags(raw_ostream &OS, const T Value,
+                               ArrayRef<EnumEntry<T>> Flags) {
   bool FlagSet = false;
-  unsigned Remaining = llvm::to_underlying(Flags);
+  unsigned Remaining = llvm::to_underlying(Value);
   while (Remaining) {
     unsigned Bit = 1u << llvm::countr_zero(Remaining);
     if (Remaining & Bit) {
       if (FlagSet)
         OS << " | ";
 
-      switch (static_cast<DescriptorRangeFlags>(Bit)) {
-      case DescriptorRangeFlags::DescriptorsVolatile:
-        OS << "DescriptorsVolatile";
-        break;
-      case DescriptorRangeFlags::DataVolatile:
-        OS << "DataVolatile";
-        break;
-      case DescriptorRangeFlags::DataStaticWhileSetAtExecute:
-        OS << "DataStaticWhileSetAtExecute";
-        break;
-      case DescriptorRangeFlags::DataStatic:
-        OS << "DataStatic";
-        break;
-      case DescriptorRangeFlags::DescriptorsStaticKeepingBufferBoundsChecks:
-        OS << "DescriptorsStaticKeepingBufferBoundsChecks";
-        break;
-      default:
+      auto MaybeFlag = getEnumName(T(Bit), Flags);
+      if (MaybeFlag)
+        OS << *MaybeFlag;
+      else
         OS << "invalid: " << Bit;
-        break;
-      }
 
       FlagSet = true;
     }
@@ -128,6 +63,68 @@ static raw_ostream &operator<<(raw_ostream &OS,
 
   if (!FlagSet)
     OS << "None";
+  return OS;
+}
+
+static const EnumEntry<RegisterType> RegisterNames[] = {
+    {"b", RegisterType::BReg},
+    {"t", RegisterType::TReg},
+    {"u", RegisterType::UReg},
+    {"s", RegisterType::SReg},
+};
+
+static raw_ostream &operator<<(raw_ostream &OS, const Register &Reg) {
+  printEnum(OS, Reg.ViewType, ArrayRef(RegisterNames));
+  OS << Reg.Number;
+
+  return OS;
+}
+
+static const EnumEntry<ShaderVisibility> VisibilityNames[] = {
+    {"All", ShaderVisibility::All},
+    {"Vertex", ShaderVisibility::Vertex},
+    {"Hull", ShaderVisibility::Hull},
+    {"Domain", ShaderVisibility::Domain},
+    {"Geometry", ShaderVisibility::Geometry},
+    {"Pixel", ShaderVisibility::Pixel},
+    {"Amplification", ShaderVisibility::Amplification},
+    {"Mesh", ShaderVisibility::Mesh},
+};
+
+static raw_ostream &operator<<(raw_ostream &OS,
+                               const ShaderVisibility &Visibility) {
+  printEnum(OS, Visibility, ArrayRef(VisibilityNames));
+
+  return OS;
+}
+
+static const EnumEntry<dxil::ResourceClass> ResourceClassNames[] = {
+    {"CBV", dxil::ResourceClass::CBuffer},
+    {"SRV", dxil::ResourceClass::SRV},
+    {"UAV", dxil::ResourceClass::UAV},
+    {"Sampler", dxil::ResourceClass::Sampler},
+};
+
+static raw_ostream &operator<<(raw_ostream &OS, const ClauseType &Type) {
+  printEnum(OS, dxil::ResourceClass(llvm::to_underlying(Type)),
+            ArrayRef(ResourceClassNames));
+
+  return OS;
+}
+
+static const EnumEntry<DescriptorRangeFlags> DescriptorRangeFlagNames[] = {
+    {"DescriptorsVolatile", DescriptorRangeFlags::DescriptorsVolatile},
+    {"DataVolatile", DescriptorRangeFlags::DataVolatile},
+    {"DataStaticWhileSetAtExecute",
+     DescriptorRangeFlags::DataStaticWhileSetAtExecute},
+    {"DataStatic", DescriptorRangeFlags::DataStatic},
+    {"DescriptorsStaticKeepingBufferBoundsChecks",
+     DescriptorRangeFlags::DescriptorsStaticKeepingBufferBoundsChecks},
+};
+
+static raw_ostream &operator<<(raw_ostream &OS,
+                               const DescriptorRangeFlags &Flags) {
+  printFlags(OS, Flags, ArrayRef(DescriptorRangeFlagNames));
 
   return OS;
 }
@@ -236,12 +233,13 @@ MDNode *MetadataBuilder::BuildRootConstants(const RootConstants &Constants) {
 
 MDNode *MetadataBuilder::BuildRootDescriptor(const RootDescriptor &Descriptor) {
   IRBuilder<> Builder(Ctx);
-  llvm::SmallString<7> Name;
-  llvm::raw_svector_ostream OS(Name);
-  OS << "Root" << ClauseType(llvm::to_underlying(Descriptor.Type));
-
+  std::optional<StringRef> TypeName =
+      getEnumName(dxil::ResourceClass(llvm::to_underlying(Descriptor.Type)),
+                  ArrayRef(ResourceClassNames));
+  assert(TypeName && "Provided an invalid Resource Class");
+  llvm::SmallString<7> Name({"Root", *TypeName});
   Metadata *Operands[] = {
-      MDString::get(Ctx, OS.str()),
+      MDString::get(Ctx, Name),
       ConstantAsMetadata::get(
           Builder.getInt32(llvm::to_underlying(Descriptor.Visibility))),
       ConstantAsMetadata::get(Builder.getInt32(Descriptor.Reg.Number)),
@@ -277,19 +275,20 @@ MDNode *MetadataBuilder::BuildDescriptorTable(const DescriptorTable &Table) {
 MDNode *MetadataBuilder::BuildDescriptorTableClause(
     const DescriptorTableClause &Clause) {
   IRBuilder<> Builder(Ctx);
-  std::string Name;
-  llvm::raw_string_ostream OS(Name);
-  OS << Clause.Type;
-  return MDNode::get(
-      Ctx, {
-               MDString::get(Ctx, OS.str()),
-               ConstantAsMetadata::get(Builder.getInt32(Clause.NumDescriptors)),
-               ConstantAsMetadata::get(Builder.getInt32(Clause.Reg.Number)),
-               ConstantAsMetadata::get(Builder.getInt32(Clause.Space)),
-               ConstantAsMetadata::get(Builder.getInt32(Clause.Offset)),
-               ConstantAsMetadata::get(
-                   Builder.getInt32(llvm::to_underlying(Clause.Flags))),
-           });
+  std::optional<StringRef> Name =
+      getEnumName(dxil::ResourceClass(llvm::to_underlying(Clause.Type)),
+                  ArrayRef(ResourceClassNames));
+  assert(Name && "Provided an invalid Resource Class");
+  Metadata *Operands[] = {
+      MDString::get(Ctx, *Name),
+      ConstantAsMetadata::get(Builder.getInt32(Clause.NumDescriptors)),
+      ConstantAsMetadata::get(Builder.getInt32(Clause.Reg.Number)),
+      ConstantAsMetadata::get(Builder.getInt32(Clause.Space)),
+      ConstantAsMetadata::get(Builder.getInt32(Clause.Offset)),
+      ConstantAsMetadata::get(
+          Builder.getInt32(llvm::to_underlying(Clause.Flags))),
+  };
+  return MDNode::get(Ctx, Operands);
 }
 
 MDNode *MetadataBuilder::BuildStaticSampler(const StaticSampler &Sampler) {


        


More information about the llvm-commits mailing list