[llvm] 63b80dd - [NFC][RootSignature] Use `llvm::EnumEntry` for serialization of Root Signature Elements (#144106)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Jun 16 11:45:23 PDT 2025
Author: Finn Plummer
Date: 2025-06-16T11:45:19-07:00
New Revision: 63b80dd01dafc92104ee43e4f0f5296d644c25ec
URL: https://github.com/llvm/llvm-project/commit/63b80dd01dafc92104ee43e4f0f5296d644c25ec
DIFF: https://github.com/llvm/llvm-project/commit/63b80dd01dafc92104ee43e4f0f5296d644c25ec.diff
LOG: [NFC][RootSignature] Use `llvm::EnumEntry` for serialization of Root Signature Elements (#144106)
It has been pointed out
[here](https://github.com/llvm/llvm-project/pull/143198#discussion_r2132877388)
that we can use `llvm::EnumEntry`, which lets us re-use the same
printing logic across all of these enumerations.
- Enables re-use of `printEnum` and `printFlags` methods via templates
- Allows easy definition of a `getEnumName` function for enum-to-string
conversion, eliminating the need to use a string stream to construct
the `Name` SmallString
- Also makes a small fix-up to how the operands of a descriptor table clause
are built, so that it is consistent with the other `Build*` methods
For reference, see the
[test cases](https://github.com/llvm/llvm-project/blob/main/llvm/unittests/Frontend/HLSLRootSignatureDumpTest.cpp),
whose expected output must not change.
Added:
Modified:
llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp b/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp
index 765a3bcbed7e2..7d744781da04f 100644
--- a/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp
+++ b/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp
@@ -15,111 +15,46 @@
#include "llvm/ADT/bit.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Metadata.h"
+#include "llvm/Support/ScopedPrinter.h"
namespace llvm {
namespace hlsl {
namespace rootsig {
-static raw_ostream &operator<<(raw_ostream &OS, const Register &Reg) {
- switch (Reg.ViewType) {
- case RegisterType::BReg:
- OS << "b";
- break;
- case RegisterType::TReg:
- OS << "t";
- break;
- case RegisterType::UReg:
- OS << "u";
- break;
- case RegisterType::SReg:
- OS << "s";
- break;
- }
- OS << Reg.Number;
- return OS;
+template <typename T>
+static std::optional<StringRef> getEnumName(const T Value,
+ ArrayRef<EnumEntry<T>> Enums) {
+ for (const auto &EnumItem : Enums)
+ if (EnumItem.Value == Value)
+ return EnumItem.Name;
+ return std::nullopt;
}
-static raw_ostream &operator<<(raw_ostream &OS,
- const ShaderVisibility &Visibility) {
- switch (Visibility) {
- case ShaderVisibility::All:
- OS << "All";
- break;
- case ShaderVisibility::Vertex:
- OS << "Vertex";
- break;
- case ShaderVisibility::Hull:
- OS << "Hull";
- break;
- case ShaderVisibility::Domain:
- OS << "Domain";
- break;
- case ShaderVisibility::Geometry:
- OS << "Geometry";
- break;
- case ShaderVisibility::Pixel:
- OS << "Pixel";
- break;
- case ShaderVisibility::Amplification:
- OS << "Amplification";
- break;
- case ShaderVisibility::Mesh:
- OS << "Mesh";
- break;
- }
-
- return OS;
-}
-
-static raw_ostream &operator<<(raw_ostream &OS, const ClauseType &Type) {
- switch (Type) {
- case ClauseType::CBuffer:
- OS << "CBV";
- break;
- case ClauseType::SRV:
- OS << "SRV";
- break;
- case ClauseType::UAV:
- OS << "UAV";
- break;
- case ClauseType::Sampler:
- OS << "Sampler";
- break;
- }
-
+template <typename T>
+static raw_ostream &printEnum(raw_ostream &OS, const T Value,
+ ArrayRef<EnumEntry<T>> Enums) {
+ auto MaybeName = getEnumName(Value, Enums);
+ if (MaybeName)
+ OS << *MaybeName;
return OS;
}
-static raw_ostream &operator<<(raw_ostream &OS,
- const DescriptorRangeFlags &Flags) {
+template <typename T>
+static raw_ostream &printFlags(raw_ostream &OS, const T Value,
+ ArrayRef<EnumEntry<T>> Flags) {
bool FlagSet = false;
- unsigned Remaining = llvm::to_underlying(Flags);
+ unsigned Remaining = llvm::to_underlying(Value);
while (Remaining) {
unsigned Bit = 1u << llvm::countr_zero(Remaining);
if (Remaining & Bit) {
if (FlagSet)
OS << " | ";
- switch (static_cast<DescriptorRangeFlags>(Bit)) {
- case DescriptorRangeFlags::DescriptorsVolatile:
- OS << "DescriptorsVolatile";
- break;
- case DescriptorRangeFlags::DataVolatile:
- OS << "DataVolatile";
- break;
- case DescriptorRangeFlags::DataStaticWhileSetAtExecute:
- OS << "DataStaticWhileSetAtExecute";
- break;
- case DescriptorRangeFlags::DataStatic:
- OS << "DataStatic";
- break;
- case DescriptorRangeFlags::DescriptorsStaticKeepingBufferBoundsChecks:
- OS << "DescriptorsStaticKeepingBufferBoundsChecks";
- break;
- default:
+ auto MaybeFlag = getEnumName(T(Bit), Flags);
+ if (MaybeFlag)
+ OS << *MaybeFlag;
+ else
OS << "invalid: " << Bit;
- break;
- }
FlagSet = true;
}
@@ -128,6 +63,68 @@ static raw_ostream &operator<<(raw_ostream &OS,
if (!FlagSet)
OS << "None";
+ return OS;
+}
+
+static const EnumEntry<RegisterType> RegisterNames[] = {
+ {"b", RegisterType::BReg},
+ {"t", RegisterType::TReg},
+ {"u", RegisterType::UReg},
+ {"s", RegisterType::SReg},
+};
+
+static raw_ostream &operator<<(raw_ostream &OS, const Register &Reg) {
+ printEnum(OS, Reg.ViewType, ArrayRef(RegisterNames));
+ OS << Reg.Number;
+
+ return OS;
+}
+
+static const EnumEntry<ShaderVisibility> VisibilityNames[] = {
+ {"All", ShaderVisibility::All},
+ {"Vertex", ShaderVisibility::Vertex},
+ {"Hull", ShaderVisibility::Hull},
+ {"Domain", ShaderVisibility::Domain},
+ {"Geometry", ShaderVisibility::Geometry},
+ {"Pixel", ShaderVisibility::Pixel},
+ {"Amplification", ShaderVisibility::Amplification},
+ {"Mesh", ShaderVisibility::Mesh},
+};
+
+static raw_ostream &operator<<(raw_ostream &OS,
+ const ShaderVisibility &Visibility) {
+ printEnum(OS, Visibility, ArrayRef(VisibilityNames));
+
+ return OS;
+}
+
+static const EnumEntry<dxil::ResourceClass> ResourceClassNames[] = {
+ {"CBV", dxil::ResourceClass::CBuffer},
+ {"SRV", dxil::ResourceClass::SRV},
+ {"UAV", dxil::ResourceClass::UAV},
+ {"Sampler", dxil::ResourceClass::Sampler},
+};
+
+static raw_ostream &operator<<(raw_ostream &OS, const ClauseType &Type) {
+ printEnum(OS, dxil::ResourceClass(llvm::to_underlying(Type)),
+ ArrayRef(ResourceClassNames));
+
+ return OS;
+}
+
+static const EnumEntry<DescriptorRangeFlags> DescriptorRangeFlagNames[] = {
+ {"DescriptorsVolatile", DescriptorRangeFlags::DescriptorsVolatile},
+ {"DataVolatile", DescriptorRangeFlags::DataVolatile},
+ {"DataStaticWhileSetAtExecute",
+ DescriptorRangeFlags::DataStaticWhileSetAtExecute},
+ {"DataStatic", DescriptorRangeFlags::DataStatic},
+ {"DescriptorsStaticKeepingBufferBoundsChecks",
+ DescriptorRangeFlags::DescriptorsStaticKeepingBufferBoundsChecks},
+};
+
+static raw_ostream &operator<<(raw_ostream &OS,
+ const DescriptorRangeFlags &Flags) {
+ printFlags(OS, Flags, ArrayRef(DescriptorRangeFlagNames));
return OS;
}
@@ -236,12 +233,13 @@ MDNode *MetadataBuilder::BuildRootConstants(const RootConstants &Constants) {
MDNode *MetadataBuilder::BuildRootDescriptor(const RootDescriptor &Descriptor) {
IRBuilder<> Builder(Ctx);
- llvm::SmallString<7> Name;
- llvm::raw_svector_ostream OS(Name);
- OS << "Root" << ClauseType(llvm::to_underlying(Descriptor.Type));
-
+ std::optional<StringRef> TypeName =
+ getEnumName(dxil::ResourceClass(llvm::to_underlying(Descriptor.Type)),
+ ArrayRef(ResourceClassNames));
+ assert(TypeName && "Provided an invalid Resource Class");
+ llvm::SmallString<7> Name({"Root", *TypeName});
Metadata *Operands[] = {
- MDString::get(Ctx, OS.str()),
+ MDString::get(Ctx, Name),
ConstantAsMetadata::get(
Builder.getInt32(llvm::to_underlying(Descriptor.Visibility))),
ConstantAsMetadata::get(Builder.getInt32(Descriptor.Reg.Number)),
@@ -277,19 +275,20 @@ MDNode *MetadataBuilder::BuildDescriptorTable(const DescriptorTable &Table) {
MDNode *MetadataBuilder::BuildDescriptorTableClause(
const DescriptorTableClause &Clause) {
IRBuilder<> Builder(Ctx);
- std::string Name;
- llvm::raw_string_ostream OS(Name);
- OS << Clause.Type;
- return MDNode::get(
- Ctx, {
- MDString::get(Ctx, OS.str()),
- ConstantAsMetadata::get(Builder.getInt32(Clause.NumDescriptors)),
- ConstantAsMetadata::get(Builder.getInt32(Clause.Reg.Number)),
- ConstantAsMetadata::get(Builder.getInt32(Clause.Space)),
- ConstantAsMetadata::get(Builder.getInt32(Clause.Offset)),
- ConstantAsMetadata::get(
- Builder.getInt32(llvm::to_underlying(Clause.Flags))),
- });
+ std::optional<StringRef> Name =
+ getEnumName(dxil::ResourceClass(llvm::to_underlying(Clause.Type)),
+ ArrayRef(ResourceClassNames));
+ assert(Name && "Provided an invalid Resource Class");
+ Metadata *Operands[] = {
+ MDString::get(Ctx, *Name),
+ ConstantAsMetadata::get(Builder.getInt32(Clause.NumDescriptors)),
+ ConstantAsMetadata::get(Builder.getInt32(Clause.Reg.Number)),
+ ConstantAsMetadata::get(Builder.getInt32(Clause.Space)),
+ ConstantAsMetadata::get(Builder.getInt32(Clause.Offset)),
+ ConstantAsMetadata::get(
+ Builder.getInt32(llvm::to_underlying(Clause.Flags))),
+ };
+ return MDNode::get(Ctx, Operands);
}
MDNode *MetadataBuilder::BuildStaticSampler(const StaticSampler &Sampler) {
More information about the llvm-commits
mailing list