[llvm-branch-commits] [llvm] [HLSL][RootSignature] Implement serialization of remaining Root Elements (PR #143198)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Fri Jun 6 13:14:20 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-hlsl
Author: Finn Plummer (inbelic)
<details>
<summary>Changes</summary>
Implements serialization of the remaining completely defined `RootElement`s, namely `RootDescriptor`s and `StaticSampler`s.
- Adds unit testing for the serialization methods
Resolves https://github.com/llvm/llvm-project/issues/138191
Resolves https://github.com/llvm/llvm-project/issues/138193
---
Full diff: https://github.com/llvm/llvm-project/pull/143198.diff
3 Files Affected:
- (modified) llvm/include/llvm/Frontend/HLSL/HLSLRootSignatureUtils.h (+6)
- (modified) llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp (+254)
- (modified) llvm/unittests/Frontend/HLSLRootSignatureDumpTest.cpp (+121)
``````````diff
diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignatureUtils.h b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignatureUtils.h
index ca20e6719f3a4..7489777670703 100644
--- a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignatureUtils.h
+++ b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignatureUtils.h
@@ -32,6 +32,12 @@ LLVM_ABI raw_ostream &operator<<(raw_ostream &OS, const RootFlags &Flags);
LLVM_ABI raw_ostream &operator<<(raw_ostream &OS,
const RootConstants &Constants);
+/// Prints \p Descriptor as
+/// `Root<Type>(<reg>, space = <n>, visibility = <v>, flags = <f>)`.
+LLVM_ABI raw_ostream &operator<<(raw_ostream &OS,
+                                 const RootDescriptor &Descriptor);
+
+/// Prints every field of \p Sampler as `StaticSampler(<reg>, filter = ...,
+/// ..., space = <n>, visibility = <v>)`.
+LLVM_ABI raw_ostream &operator<<(raw_ostream &OS,
+                                 const StaticSampler &Sampler);
+
LLVM_ABI raw_ostream &operator<<(raw_ostream &OS,
const DescriptorTableClause &Clause);
diff --git a/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp b/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp
index 24486a55ecf6a..70c3e72c1f806 100644
--- a/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp
+++ b/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp
@@ -71,6 +71,199 @@ static raw_ostream &operator<<(raw_ostream &OS,
return OS;
}
+/// Writes the enumerator name of \p Filter to \p OS.
+static raw_ostream &operator<<(raw_ostream &OS, const SamplerFilter &Filter) {
+  switch (Filter) {
+  case SamplerFilter::MinMagMipPoint:
+    return OS << "MinMagMipPoint";
+  case SamplerFilter::MinMagPointMipLinear:
+    return OS << "MinMagPointMipLinear";
+  case SamplerFilter::MinPointMagLinearMipPoint:
+    return OS << "MinPointMagLinearMipPoint";
+  case SamplerFilter::MinPointMagMipLinear:
+    return OS << "MinPointMagMipLinear";
+  case SamplerFilter::MinLinearMagMipPoint:
+    return OS << "MinLinearMagMipPoint";
+  case SamplerFilter::MinLinearMagPointMipLinear:
+    return OS << "MinLinearMagPointMipLinear";
+  case SamplerFilter::MinMagLinearMipPoint:
+    return OS << "MinMagLinearMipPoint";
+  case SamplerFilter::MinMagMipLinear:
+    return OS << "MinMagMipLinear";
+  case SamplerFilter::Anisotropic:
+    return OS << "Anisotropic";
+  case SamplerFilter::ComparisonMinMagMipPoint:
+    return OS << "ComparisonMinMagMipPoint";
+  case SamplerFilter::ComparisonMinMagPointMipLinear:
+    return OS << "ComparisonMinMagPointMipLinear";
+  case SamplerFilter::ComparisonMinPointMagLinearMipPoint:
+    return OS << "ComparisonMinPointMagLinearMipPoint";
+  case SamplerFilter::ComparisonMinPointMagMipLinear:
+    return OS << "ComparisonMinPointMagMipLinear";
+  case SamplerFilter::ComparisonMinLinearMagMipPoint:
+    return OS << "ComparisonMinLinearMagMipPoint";
+  case SamplerFilter::ComparisonMinLinearMagPointMipLinear:
+    return OS << "ComparisonMinLinearMagPointMipLinear";
+  case SamplerFilter::ComparisonMinMagLinearMipPoint:
+    return OS << "ComparisonMinMagLinearMipPoint";
+  case SamplerFilter::ComparisonMinMagMipLinear:
+    return OS << "ComparisonMinMagMipLinear";
+  case SamplerFilter::ComparisonAnisotropic:
+    return OS << "ComparisonAnisotropic";
+  case SamplerFilter::MinimumMinMagMipPoint:
+    return OS << "MinimumMinMagMipPoint";
+  case SamplerFilter::MinimumMinMagPointMipLinear:
+    return OS << "MinimumMinMagPointMipLinear";
+  case SamplerFilter::MinimumMinPointMagLinearMipPoint:
+    return OS << "MinimumMinPointMagLinearMipPoint";
+  case SamplerFilter::MinimumMinPointMagMipLinear:
+    return OS << "MinimumMinPointMagMipLinear";
+  case SamplerFilter::MinimumMinLinearMagMipPoint:
+    return OS << "MinimumMinLinearMagMipPoint";
+  case SamplerFilter::MinimumMinLinearMagPointMipLinear:
+    return OS << "MinimumMinLinearMagPointMipLinear";
+  case SamplerFilter::MinimumMinMagLinearMipPoint:
+    return OS << "MinimumMinMagLinearMipPoint";
+  case SamplerFilter::MinimumMinMagMipLinear:
+    return OS << "MinimumMinMagMipLinear";
+  case SamplerFilter::MinimumAnisotropic:
+    return OS << "MinimumAnisotropic";
+  case SamplerFilter::MaximumMinMagMipPoint:
+    return OS << "MaximumMinMagMipPoint";
+  case SamplerFilter::MaximumMinMagPointMipLinear:
+    return OS << "MaximumMinMagPointMipLinear";
+  case SamplerFilter::MaximumMinPointMagLinearMipPoint:
+    return OS << "MaximumMinPointMagLinearMipPoint";
+  case SamplerFilter::MaximumMinPointMagMipLinear:
+    return OS << "MaximumMinPointMagMipLinear";
+  case SamplerFilter::MaximumMinLinearMagMipPoint:
+    return OS << "MaximumMinLinearMagMipPoint";
+  case SamplerFilter::MaximumMinLinearMagPointMipLinear:
+    return OS << "MaximumMinLinearMagPointMipLinear";
+  case SamplerFilter::MaximumMinMagLinearMipPoint:
+    return OS << "MaximumMinMagLinearMipPoint";
+  case SamplerFilter::MaximumMinMagMipLinear:
+    return OS << "MaximumMinMagMipLinear";
+  case SamplerFilter::MaximumAnisotropic:
+    return OS << "MaximumAnisotropic";
+  }
+
+  // A value outside the enumeration produces no output.
+  return OS;
+}
+
+/// Writes the enumerator name of \p Address to \p OS.
+static raw_ostream &operator<<(raw_ostream &OS,
+                               const TextureAddressMode &Address) {
+  switch (Address) {
+  case TextureAddressMode::Wrap:
+    return OS << "Wrap";
+  case TextureAddressMode::Mirror:
+    return OS << "Mirror";
+  case TextureAddressMode::Clamp:
+    return OS << "Clamp";
+  case TextureAddressMode::Border:
+    return OS << "Border";
+  case TextureAddressMode::MirrorOnce:
+    return OS << "MirrorOnce";
+  }
+
+  // A value outside the enumeration produces no output.
+  return OS;
+}
+
+/// Writes the enumerator name of \p CompFunc to \p OS.
+static raw_ostream &operator<<(raw_ostream &OS,
+                               const ComparisonFunc &CompFunc) {
+  switch (CompFunc) {
+  case ComparisonFunc::Never:
+    return OS << "Never";
+  case ComparisonFunc::Less:
+    return OS << "Less";
+  case ComparisonFunc::Equal:
+    return OS << "Equal";
+  case ComparisonFunc::LessEqual:
+    return OS << "LessEqual";
+  case ComparisonFunc::Greater:
+    return OS << "Greater";
+  case ComparisonFunc::NotEqual:
+    return OS << "NotEqual";
+  case ComparisonFunc::GreaterEqual:
+    return OS << "GreaterEqual";
+  case ComparisonFunc::Always:
+    return OS << "Always";
+  }
+
+  // A value outside the enumeration produces no output.
+  return OS;
+}
+
+/// Writes the enumerator name of \p BorderColor to \p OS.
+static raw_ostream &operator<<(raw_ostream &OS,
+                               const StaticBorderColor &BorderColor) {
+  switch (BorderColor) {
+  case StaticBorderColor::TransparentBlack:
+    return OS << "TransparentBlack";
+  case StaticBorderColor::OpaqueBlack:
+    return OS << "OpaqueBlack";
+  case StaticBorderColor::OpaqueWhite:
+    return OS << "OpaqueWhite";
+  case StaticBorderColor::OpaqueBlackUint:
+    return OS << "OpaqueBlackUint";
+  case StaticBorderColor::OpaqueWhiteUint:
+    return OS << "OpaqueWhiteUint";
+  }
+
+  // A value outside the enumeration produces no output.
+  return OS;
+}
+
static raw_ostream &operator<<(raw_ostream &OS, const ClauseType &Type) {
switch (Type) {
case ClauseType::CBuffer:
@@ -132,6 +325,42 @@ static raw_ostream &operator<<(raw_ostream &OS,
return OS;
}
+/// Prints the set bits of \p Flags, lowest bit first, as a " | "-separated
+/// list of flag names. Prints "None" when no bit is set, and prints a bit
+/// with no known name as "invalid: <bit value>".
+static raw_ostream &operator<<(raw_ostream &OS,
+                               const RootDescriptorFlags &Flags) {
+  bool FlagSet = false;
+  unsigned Remaining = llvm::to_underlying(Flags);
+  while (Remaining) {
+    // Isolate the lowest set bit. It is by construction present in Remaining,
+    // so no separate membership test is needed.
+    unsigned Bit = 1u << llvm::countr_zero(Remaining);
+    if (FlagSet)
+      OS << " | ";
+
+    switch (static_cast<RootDescriptorFlags>(Bit)) {
+    case RootDescriptorFlags::DataVolatile:
+      OS << "DataVolatile";
+      break;
+    case RootDescriptorFlags::DataStaticWhileSetAtExecute:
+      OS << "DataStaticWhileSetAtExecute";
+      break;
+    case RootDescriptorFlags::DataStatic:
+      OS << "DataStatic";
+      break;
+    default:
+      OS << "invalid: " << Bit;
+      break;
+    }
+
+    FlagSet = true;
+    Remaining &= ~Bit;
+  }
+
+  if (!FlagSet)
+    OS << "None";
+
+  return OS;
+}
+
raw_ostream &operator<<(raw_ostream &OS, const RootFlags &Flags) {
OS << "RootFlags(";
bool FlagSet = false;
@@ -205,6 +434,31 @@ raw_ostream &operator<<(raw_ostream &OS, const RootConstants &Constants) {
return OS;
}
+/// Prints \p Descriptor as
+/// `Root<Type>(<reg>, space = <n>, visibility = <v>, flags = <f>)`.
+raw_ostream &operator<<(raw_ostream &OS, const RootDescriptor &Descriptor) {
+  // The descriptor's type is reinterpreted as a ClauseType so that the
+  // existing ClauseType printer spells the kind (e.g. "CBV"). This relies on
+  // both enums sharing underlying values.
+  const auto Kind =
+      static_cast<ClauseType>(llvm::to_underlying(Descriptor.Type));
+  OS << "Root" << Kind << "(" << Descriptor.Reg;
+  OS << ", space = " << Descriptor.Space;
+  OS << ", visibility = " << Descriptor.Visibility;
+  OS << ", flags = " << Descriptor.Flags;
+  return OS << ")";
+}
+
+/// Prints every field of \p Sampler as `StaticSampler(<reg>, filter = ...,
+/// ..., space = <n>, visibility = <v>)`, one "name = value" pair per field.
+raw_ostream &operator<<(raw_ostream &OS, const StaticSampler &Sampler) {
+  OS << "StaticSampler(" << Sampler.Reg;
+  OS << ", filter = " << Sampler.Filter;
+  OS << ", addressU = " << Sampler.AddressU;
+  OS << ", addressV = " << Sampler.AddressV;
+  OS << ", addressW = " << Sampler.AddressW;
+  OS << ", mipLODBias = " << Sampler.MipLODBias;
+  OS << ", maxAnisotropy = " << Sampler.MaxAnisotropy;
+  OS << ", comparisonFunc = " << Sampler.CompFunc;
+  OS << ", borderColor = " << Sampler.BorderColor;
+  OS << ", minLOD = " << Sampler.MinLOD;
+  OS << ", maxLOD = " << Sampler.MaxLOD;
+  OS << ", space = " << Sampler.Space;
+  OS << ", visibility = " << Sampler.Visibility;
+  return OS << ")";
+}
+
raw_ostream &operator<<(raw_ostream &OS, const DescriptorTable &Table) {
OS << "DescriptorTable(numClauses = " << Table.NumClauses
<< ", visibility = " << Table.Visibility << ")";
diff --git a/llvm/unittests/Frontend/HLSLRootSignatureDumpTest.cpp b/llvm/unittests/Frontend/HLSLRootSignatureDumpTest.cpp
index 1a0c8e2a16396..831c5dd585fab 100644
--- a/llvm/unittests/Frontend/HLSLRootSignatureDumpTest.cpp
+++ b/llvm/unittests/Frontend/HLSLRootSignatureDumpTest.cpp
@@ -177,4 +177,125 @@ TEST(HLSLRootSignatureTest, AllRootFlagsDump) {
EXPECT_EQ(Out, Expected);
}
+TEST(HLSLRootSignatureTest, RootCBVDump) {
+  // Only type, register and default flags are set. Space and visibility keep
+  // their defaults, which the expected string spells out.
+  RootDescriptor Desc;
+  Desc.Type = DescriptorType::CBuffer;
+  Desc.Reg = {RegisterType::BReg, 0};
+  Desc.setDefaultFlags();
+
+  std::string Actual;
+  llvm::raw_string_ostream Stream(Actual);
+  Stream << Desc;
+  Stream.flush();
+
+  const std::string Expected = "RootCBV(b0, space = 0, "
+                               "visibility = All, "
+                               "flags = DataStaticWhileSetAtExecute)";
+  EXPECT_EQ(Actual, Expected);
+}
+
+TEST(HLSLRootSignatureTest, RootSRVDump) {
+  RootDescriptor Desc;
+  Desc.Type = DescriptorType::SRV;
+  Desc.Reg = {RegisterType::TReg, 0};
+  Desc.Space = 42;
+  Desc.Visibility = ShaderVisibility::Geometry;
+  // With no flag bits set, the flags printer emits "None".
+  Desc.Flags = RootDescriptorFlags::None;
+
+  std::string Actual;
+  llvm::raw_string_ostream Stream(Actual);
+  Stream << Desc;
+  Stream.flush();
+
+  const std::string Expected =
+      "RootSRV(t0, space = 42, visibility = Geometry, flags = None)";
+  EXPECT_EQ(Actual, Expected);
+}
+
+TEST(HLSLRootSignatureTest, RootUAVDump) {
+  RootDescriptor Desc;
+  Desc.Type = DescriptorType::UAV;
+  Desc.Reg = {RegisterType::UReg, 92374};
+  Desc.Space = 932847;
+  Desc.Visibility = ShaderVisibility::Hull;
+  Desc.Flags = RootDescriptorFlags::ValidFlags;
+
+  std::string Actual;
+  llvm::raw_string_ostream Stream(Actual);
+  Stream << Desc;
+  Stream.flush();
+
+  // Set flag bits are printed from lowest to highest, joined by " | ".
+  const std::string Expected =
+      "RootUAV(u92374, space = 932847, visibility = Hull, flags = "
+      "DataVolatile | "
+      "DataStaticWhileSetAtExecute | "
+      "DataStatic)";
+  EXPECT_EQ(Actual, Expected);
+}
+
+TEST(HLSLRootSignatureTest, DefaultStaticSamplerDump) {
+  // Only the register is assigned. Every other field keeps its default value,
+  // as captured in the expected string.
+  StaticSampler Smp;
+  Smp.Reg = {RegisterType::SReg, 0};
+
+  std::string Actual;
+  llvm::raw_string_ostream Stream(Actual);
+  Stream << Smp;
+  Stream.flush();
+
+  const std::string Expected = "StaticSampler(s0, "
+                               "filter = Anisotropic, "
+                               "addressU = Wrap, "
+                               "addressV = Wrap, "
+                               "addressW = Wrap, "
+                               "mipLODBias = 0.000000e+00, "
+                               "maxAnisotropy = 16, "
+                               "comparisonFunc = LessEqual, "
+                               "borderColor = OpaqueWhite, "
+                               "minLOD = 0.000000e+00, "
+                               "maxLOD = 3.402823e+38, "
+                               "space = 0, "
+                               "visibility = All"
+                               ")";
+  EXPECT_EQ(Actual, Expected);
+}
+
+TEST(HLSLRootSignatureTest, DefinedStaticSamplerDump) {
+  // Every field is overridden with a non-default value.
+  StaticSampler Smp;
+  Smp.Reg = {RegisterType::SReg, 0};
+  Smp.Filter = SamplerFilter::ComparisonMinMagLinearMipPoint;
+  Smp.AddressU = TextureAddressMode::Mirror;
+  Smp.AddressV = TextureAddressMode::Border;
+  Smp.AddressW = TextureAddressMode::Clamp;
+  Smp.MipLODBias = 4.8f;
+  Smp.MaxAnisotropy = 32;
+  Smp.CompFunc = ComparisonFunc::NotEqual;
+  Smp.BorderColor = StaticBorderColor::OpaqueBlack;
+  Smp.MinLOD = 1.0f;
+  Smp.MaxLOD = 32.0f;
+  Smp.Space = 7;
+  Smp.Visibility = ShaderVisibility::Domain;
+
+  std::string Actual;
+  llvm::raw_string_ostream Stream(Actual);
+  Stream << Smp;
+  Stream.flush();
+
+  const std::string Expected = "StaticSampler(s0, "
+                               "filter = ComparisonMinMagLinearMipPoint, "
+                               "addressU = Mirror, "
+                               "addressV = Border, "
+                               "addressW = Clamp, "
+                               "mipLODBias = 4.800000e+00, "
+                               "maxAnisotropy = 32, "
+                               "comparisonFunc = NotEqual, "
+                               "borderColor = OpaqueBlack, "
+                               "minLOD = 1.000000e+00, "
+                               "maxLOD = 3.200000e+01, "
+                               "space = 7, "
+                               "visibility = Domain"
+                               ")";
+  EXPECT_EQ(Actual, Expected);
+}
+
} // namespace
``````````
</details>
https://github.com/llvm/llvm-project/pull/143198
More information about the llvm-branch-commits
mailing list