[llvm] [NFC][RootSignature] Use `llvm::EnumEntry` for serialization of Root Signature Elements (PR #144106)
Finn Plummer via llvm-commits
llvm-commits at lists.llvm.org
Fri Jun 13 09:04:52 PDT 2025
https://github.com/inbelic created https://github.com/llvm/llvm-project/pull/144106
- Enables reuse of the `printEnum` and `printFlags` helper functions across enum types via templates
- Adds a `getEnumName` helper for enum-to-string conversion, eliminating the need for a string stream when constructing the `Name` string
- Also makes a small fix-up so that `BuildDescriptorTableClause` collects its operands in a local `Operands` array before calling `MDNode::get`, consistent with the other `Build*` methods
>From 5f20b415602edb4718edc9198ec878db3ebb2a7f Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Fri, 6 Jun 2025 21:33:15 +0000
Subject: [PATCH 1/3] nfc: use llvm::EnumEntry to convert Enum to Strings
---
.../Frontend/HLSL/HLSLRootSignatureUtils.cpp | 172 +++++++++---------
1 file changed, 85 insertions(+), 87 deletions(-)
diff --git a/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp b/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp
index 765a3bcbed7e2..79eee0b12b304 100644
--- a/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp
+++ b/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp
@@ -15,112 +15,48 @@
#include "llvm/ADT/bit.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Metadata.h"
+#include "llvm/Support/ScopedPrinter.h"
namespace llvm {
namespace hlsl {
namespace rootsig {
-static raw_ostream &operator<<(raw_ostream &OS, const Register &Reg) {
- switch (Reg.ViewType) {
- case RegisterType::BReg:
- OS << "b";
- break;
- case RegisterType::TReg:
- OS << "t";
- break;
- case RegisterType::UReg:
- OS << "u";
- break;
- case RegisterType::SReg:
- OS << "s";
- break;
- }
- OS << Reg.Number;
- return OS;
+template <typename T>
+static StringRef getEnumName(const T Value, ArrayRef<EnumEntry<T>> Enums) {
+ for (const auto &EnumItem : Enums)
+ if (EnumItem.Value == Value)
+ return EnumItem.Name;
+ return "";
}
-static raw_ostream &operator<<(raw_ostream &OS,
- const ShaderVisibility &Visibility) {
- switch (Visibility) {
- case ShaderVisibility::All:
- OS << "All";
- break;
- case ShaderVisibility::Vertex:
- OS << "Vertex";
- break;
- case ShaderVisibility::Hull:
- OS << "Hull";
- break;
- case ShaderVisibility::Domain:
- OS << "Domain";
- break;
- case ShaderVisibility::Geometry:
- OS << "Geometry";
- break;
- case ShaderVisibility::Pixel:
- OS << "Pixel";
- break;
- case ShaderVisibility::Amplification:
- OS << "Amplification";
- break;
- case ShaderVisibility::Mesh:
- OS << "Mesh";
- break;
- }
+template <typename T>
+static raw_ostream &printEnum(raw_ostream &OS, const T Value,
+ ArrayRef<EnumEntry<T>> Enums) {
+ OS << getEnumName(Value, Enums);
return OS;
}
-static raw_ostream &operator<<(raw_ostream &OS, const ClauseType &Type) {
- switch (Type) {
- case ClauseType::CBuffer:
- OS << "CBV";
- break;
- case ClauseType::SRV:
- OS << "SRV";
- break;
- case ClauseType::UAV:
- OS << "UAV";
- break;
- case ClauseType::Sampler:
- OS << "Sampler";
- break;
- }
-
- return OS;
-}
-
-static raw_ostream &operator<<(raw_ostream &OS,
- const DescriptorRangeFlags &Flags) {
+template <typename T>
+static raw_ostream &printFlags(raw_ostream &OS, const T Value,
+ ArrayRef<EnumEntry<T>> Flags) {
bool FlagSet = false;
- unsigned Remaining = llvm::to_underlying(Flags);
+ unsigned Remaining = llvm::to_underlying(Value);
while (Remaining) {
unsigned Bit = 1u << llvm::countr_zero(Remaining);
if (Remaining & Bit) {
if (FlagSet)
OS << " | ";
- switch (static_cast<DescriptorRangeFlags>(Bit)) {
- case DescriptorRangeFlags::DescriptorsVolatile:
- OS << "DescriptorsVolatile";
- break;
- case DescriptorRangeFlags::DataVolatile:
- OS << "DataVolatile";
- break;
- case DescriptorRangeFlags::DataStaticWhileSetAtExecute:
- OS << "DataStaticWhileSetAtExecute";
- break;
- case DescriptorRangeFlags::DataStatic:
- OS << "DataStatic";
- break;
- case DescriptorRangeFlags::DescriptorsStaticKeepingBufferBoundsChecks:
- OS << "DescriptorsStaticKeepingBufferBoundsChecks";
- break;
- default:
+ bool Found = false;
+ for (const auto &FlagItem : Flags)
+ if (FlagItem.Value == T(Bit)) {
+ OS << FlagItem.Name;
+ Found = true;
+ break;
+ }
+ if (!Found)
OS << "invalid: " << Bit;
- break;
- }
-
FlagSet = true;
}
Remaining &= ~Bit;
@@ -128,6 +64,68 @@ static raw_ostream &operator<<(raw_ostream &OS,
if (!FlagSet)
OS << "None";
+ return OS;
+}
+
+static const EnumEntry<RegisterType> RegisterNames[] = {
+ {"b", RegisterType::BReg},
+ {"t", RegisterType::TReg},
+ {"u", RegisterType::UReg},
+ {"s", RegisterType::SReg},
+};
+
+static raw_ostream &operator<<(raw_ostream &OS, const Register &Reg) {
+ printEnum(OS, Reg.ViewType, ArrayRef(RegisterNames));
+ OS << Reg.Number;
+
+ return OS;
+}
+
+static const EnumEntry<ShaderVisibility> VisibilityNames[] = {
+ {"All", ShaderVisibility::All},
+ {"Vertex", ShaderVisibility::Vertex},
+ {"Hull", ShaderVisibility::Hull},
+ {"Domain", ShaderVisibility::Domain},
+ {"Geometry", ShaderVisibility::Geometry},
+ {"Pixel", ShaderVisibility::Pixel},
+ {"Amplification", ShaderVisibility::Amplification},
+ {"Mesh", ShaderVisibility::Mesh},
+};
+
+static raw_ostream &operator<<(raw_ostream &OS,
+ const ShaderVisibility &Visibility) {
+ printEnum(OS, Visibility, ArrayRef(VisibilityNames));
+
+ return OS;
+}
+
+static const EnumEntry<dxil::ResourceClass> ResourceClassNames[] = {
+ {"CBV", dxil::ResourceClass::CBuffer},
+ {"SRV", dxil::ResourceClass::SRV},
+ {"UAV", dxil::ResourceClass::UAV},
+ {"Sampler", dxil::ResourceClass::Sampler},
+};
+
+static raw_ostream &operator<<(raw_ostream &OS, const ClauseType &Type) {
+ printEnum(OS, dxil::ResourceClass(llvm::to_underlying(Type)),
+ ArrayRef(ResourceClassNames));
+
+ return OS;
+}
+
+static const EnumEntry<DescriptorRangeFlags> DescriptorRangeFlagNames[] = {
+ {"DescriptorsVolatile", DescriptorRangeFlags::DescriptorsVolatile},
+ {"DataVolatile", DescriptorRangeFlags::DataVolatile},
+ {"DataStaticWhileSetAtExecute",
+ DescriptorRangeFlags::DataStaticWhileSetAtExecute},
+ {"DataStatic", DescriptorRangeFlags::DataStatic},
+ {"DescriptorsStaticKeepingBufferBoundsChecks",
+ DescriptorRangeFlags::DescriptorsStaticKeepingBufferBoundsChecks},
+};
+
+static raw_ostream &operator<<(raw_ostream &OS,
+ const DescriptorRangeFlags &Flags) {
+ printFlags(OS, Flags, ArrayRef(DescriptorRangeFlagNames));
return OS;
}
>From 02931201d5d7dd11d3e1a2ac7820781103f565b7 Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Fri, 6 Jun 2025 21:43:16 +0000
Subject: [PATCH 2/3] nfc: use getEnumName instead of operator<<
---
.../Frontend/HLSL/HLSLRootSignatureUtils.cpp | 18 +++++++++---------
1 file changed, 9 insertions(+), 9 deletions(-)
diff --git a/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp b/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp
index 79eee0b12b304..3d8f90399dfc5 100644
--- a/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp
+++ b/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp
@@ -234,12 +234,12 @@ MDNode *MetadataBuilder::BuildRootConstants(const RootConstants &Constants) {
MDNode *MetadataBuilder::BuildRootDescriptor(const RootDescriptor &Descriptor) {
IRBuilder<> Builder(Ctx);
- llvm::SmallString<7> Name;
- llvm::raw_svector_ostream OS(Name);
- OS << "Root" << ClauseType(llvm::to_underlying(Descriptor.Type));
-
+ StringRef TypeName =
+ getEnumName(dxil::ResourceClass(llvm::to_underlying(Descriptor.Type)),
+ ArrayRef(ResourceClassNames));
+ llvm::SmallString<7> Name({"Root", TypeName});
Metadata *Operands[] = {
- MDString::get(Ctx, OS.str()),
+ MDString::get(Ctx, Name),
ConstantAsMetadata::get(
Builder.getInt32(llvm::to_underlying(Descriptor.Visibility))),
ConstantAsMetadata::get(Builder.getInt32(Descriptor.Reg.Number)),
@@ -275,12 +275,12 @@ MDNode *MetadataBuilder::BuildDescriptorTable(const DescriptorTable &Table) {
MDNode *MetadataBuilder::BuildDescriptorTableClause(
const DescriptorTableClause &Clause) {
IRBuilder<> Builder(Ctx);
- std::string Name;
- llvm::raw_string_ostream OS(Name);
- OS << Clause.Type;
+ StringRef Name =
+ getEnumName(dxil::ResourceClass(llvm::to_underlying(Clause.Type)),
+ ArrayRef(ResourceClassNames));
return MDNode::get(
Ctx, {
- MDString::get(Ctx, OS.str()),
+ MDString::get(Ctx, Name),
ConstantAsMetadata::get(Builder.getInt32(Clause.NumDescriptors)),
ConstantAsMetadata::get(Builder.getInt32(Clause.Reg.Number)),
ConstantAsMetadata::get(Builder.getInt32(Clause.Space)),
>From d8cc1f9976aba530f882c4883aaa985d1e33f518 Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Fri, 6 Jun 2025 21:44:24 +0000
Subject: [PATCH 3/3] nfc: use operands to fix formatting
---
.../Frontend/HLSL/HLSLRootSignatureUtils.cpp | 20 +++++++++----------
1 file changed, 10 insertions(+), 10 deletions(-)
diff --git a/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp b/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp
index 3d8f90399dfc5..ab5ced523996a 100644
--- a/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp
+++ b/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp
@@ -278,16 +278,16 @@ MDNode *MetadataBuilder::BuildDescriptorTableClause(
StringRef Name =
getEnumName(dxil::ResourceClass(llvm::to_underlying(Clause.Type)),
ArrayRef(ResourceClassNames));
- return MDNode::get(
- Ctx, {
- MDString::get(Ctx, Name),
- ConstantAsMetadata::get(Builder.getInt32(Clause.NumDescriptors)),
- ConstantAsMetadata::get(Builder.getInt32(Clause.Reg.Number)),
- ConstantAsMetadata::get(Builder.getInt32(Clause.Space)),
- ConstantAsMetadata::get(Builder.getInt32(Clause.Offset)),
- ConstantAsMetadata::get(
- Builder.getInt32(llvm::to_underlying(Clause.Flags))),
- });
+ Metadata *Operands[] = {
+ MDString::get(Ctx, Name),
+ ConstantAsMetadata::get(Builder.getInt32(Clause.NumDescriptors)),
+ ConstantAsMetadata::get(Builder.getInt32(Clause.Reg.Number)),
+ ConstantAsMetadata::get(Builder.getInt32(Clause.Space)),
+ ConstantAsMetadata::get(Builder.getInt32(Clause.Offset)),
+ ConstantAsMetadata::get(
+ Builder.getInt32(llvm::to_underlying(Clause.Flags))),
+ };
+ return MDNode::get(Ctx, Operands);
}
MDNode *MetadataBuilder::BuildStaticSampler(const StaticSampler &Sampler) {
More information about the llvm-commits
mailing list