[llvm] [DirectX][NFC] Refactoring DirectX backend to not use `llvm::to_underlying` in switch cases. (PR #151032)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Aug 6 11:03:51 PDT 2025
https://github.com/joaosaffran updated https://github.com/llvm/llvm-project/pull/151032
>From ea35c412bbfbefb50078bdd21b44156e1e1efe48 Mon Sep 17 00:00:00 2001
From: Joao Saffran <{ID}+{username}@users.noreply.github.com>
Date: Mon, 28 Jul 2025 13:02:34 -0700
Subject: [PATCH 1/5] remove the to_underlying switch
---
.../Frontend/HLSL/RootSignatureMetadata.cpp | 13 ++++---
.../HLSL/RootSignatureValidations.cpp | 7 ++--
llvm/lib/MC/DXContainerRootSignature.cpp | 36 ++++++++++++-------
llvm/lib/ObjectYAML/DXContainerEmitter.cpp | 28 +++++++++------
llvm/lib/ObjectYAML/DXContainerYAML.cpp | 18 ++++++----
llvm/lib/Target/DirectX/DXILRootSignature.cpp | 14 ++++----
6 files changed, 72 insertions(+), 44 deletions(-)
diff --git a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
index 53f59349ae029..ed1e49c10818d 100644
--- a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
+++ b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
@@ -559,11 +559,14 @@ bool MetadataParser::validateRootSignature(
assert(dxbc::isValidParameterType(Info.Header.ParameterType) &&
"Invalid value for ParameterType");
- switch (Info.Header.ParameterType) {
+ dxbc::RootParameterType PT =
+ static_cast<dxbc::RootParameterType>(Info.Header.ParameterType);
- case llvm::to_underlying(dxbc::RootParameterType::CBV):
- case llvm::to_underlying(dxbc::RootParameterType::UAV):
- case llvm::to_underlying(dxbc::RootParameterType::SRV): {
+ switch (PT) {
+
+ case dxbc::RootParameterType::CBV:
+ case dxbc::RootParameterType::UAV:
+ case dxbc::RootParameterType::SRV: {
const dxbc::RTS0::v2::RootDescriptor &Descriptor =
RSD.ParametersContainer.getRootDescriptor(Info.Location);
if (!llvm::hlsl::rootsig::verifyRegisterValue(Descriptor.ShaderRegister))
@@ -580,7 +583,7 @@ bool MetadataParser::validateRootSignature(
}
break;
}
- case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
+ case dxbc::RootParameterType::DescriptorTable: {
const mcdxbc::DescriptorTable &Table =
RSD.ParametersContainer.getDescriptorTable(Info.Location);
for (const dxbc::RTS0::v2::DescriptorRange &Range : Table) {
diff --git a/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp b/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp
index f11c7d2033bfb..e55d3ec15c19d 100644
--- a/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp
+++ b/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp
@@ -53,10 +53,9 @@ bool verifyRootDescriptorFlag(uint32_t Version, uint32_t FlagsVal) {
bool verifyRangeType(uint32_t Type) {
switch (Type) {
- case llvm::to_underlying(dxbc::DescriptorRangeType::CBV):
- case llvm::to_underlying(dxbc::DescriptorRangeType::SRV):
- case llvm::to_underlying(dxbc::DescriptorRangeType::UAV):
- case llvm::to_underlying(dxbc::DescriptorRangeType::Sampler):
+#define DESCRIPTOR_RANGE(Num, Val) \
+ case llvm::to_underlying(dxbc::DescriptorRangeType::Val):
+#include "llvm/BinaryFormat/DXContainerConstants.def"
return true;
};
diff --git a/llvm/lib/MC/DXContainerRootSignature.cpp b/llvm/lib/MC/DXContainerRootSignature.cpp
index 482280b5ef289..454a0b88178f5 100644
--- a/llvm/lib/MC/DXContainerRootSignature.cpp
+++ b/llvm/lib/MC/DXContainerRootSignature.cpp
@@ -8,6 +8,7 @@
#include "llvm/MC/DXContainerRootSignature.h"
#include "llvm/ADT/SmallString.h"
+#include "llvm/BinaryFormat/DXContainer.h"
#include "llvm/Support/EndianStream.h"
using namespace llvm;
@@ -35,20 +36,27 @@ size_t RootSignatureDesc::getSize() const {
StaticSamplers.size() * sizeof(dxbc::RTS0::v1::StaticSampler);
for (const RootParameterInfo &I : ParametersContainer) {
- switch (I.Header.ParameterType) {
- case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit):
+ // Invalid parameters are allowed while writing.
+ if (!dxbc::isValidParameterType(I.Header.ParameterType))
+ continue;
+
+ dxbc::RootParameterType PT =
+ static_cast<dxbc::RootParameterType>(I.Header.ParameterType);
+
+ switch (PT) {
+ case dxbc::RootParameterType::Constants32Bit:
Size += sizeof(dxbc::RTS0::v1::RootConstants);
break;
- case llvm::to_underlying(dxbc::RootParameterType::CBV):
- case llvm::to_underlying(dxbc::RootParameterType::SRV):
- case llvm::to_underlying(dxbc::RootParameterType::UAV):
+ case dxbc::RootParameterType::CBV:
+ case dxbc::RootParameterType::SRV:
+ case dxbc::RootParameterType::UAV:
if (Version == 1)
Size += sizeof(dxbc::RTS0::v1::RootDescriptor);
else
Size += sizeof(dxbc::RTS0::v2::RootDescriptor);
break;
- case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable):
+ case dxbc::RootParameterType::DescriptorTable:
const DescriptorTable &Table =
ParametersContainer.getDescriptorTable(I.Location);
@@ -97,8 +105,12 @@ void RootSignatureDesc::write(raw_ostream &OS) const {
for (size_t I = 0; I < NumParameters; ++I) {
rewriteOffsetToCurrentByte(BOS, ParamsOffsets[I]);
const auto &[Type, Loc] = ParametersContainer.getTypeAndLocForParameter(I);
- switch (Type) {
- case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): {
+ if (!dxbc::isValidParameterType(Type))
+ continue;
+ dxbc::RootParameterType PT = static_cast<dxbc::RootParameterType>(Type);
+
+ switch (PT) {
+ case dxbc::RootParameterType::Constants32Bit: {
const dxbc::RTS0::v1::RootConstants &Constants =
ParametersContainer.getConstant(Loc);
support::endian::write(BOS, Constants.ShaderRegister,
@@ -109,9 +121,9 @@ void RootSignatureDesc::write(raw_ostream &OS) const {
llvm::endianness::little);
break;
}
- case llvm::to_underlying(dxbc::RootParameterType::CBV):
- case llvm::to_underlying(dxbc::RootParameterType::SRV):
- case llvm::to_underlying(dxbc::RootParameterType::UAV): {
+ case dxbc::RootParameterType::CBV:
+ case dxbc::RootParameterType::SRV:
+ case dxbc::RootParameterType::UAV: {
const dxbc::RTS0::v2::RootDescriptor &Descriptor =
ParametersContainer.getRootDescriptor(Loc);
@@ -123,7 +135,7 @@ void RootSignatureDesc::write(raw_ostream &OS) const {
support::endian::write(BOS, Descriptor.Flags, llvm::endianness::little);
break;
}
- case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
+ case dxbc::RootParameterType::DescriptorTable: {
const DescriptorTable &Table =
ParametersContainer.getDescriptorTable(Loc);
support::endian::write(BOS, (uint32_t)Table.Ranges.size(),
diff --git a/llvm/lib/ObjectYAML/DXContainerEmitter.cpp b/llvm/lib/ObjectYAML/DXContainerEmitter.cpp
index 043b575a43b11..35e1c3e3b5953 100644
--- a/llvm/lib/ObjectYAML/DXContainerEmitter.cpp
+++ b/llvm/lib/ObjectYAML/DXContainerEmitter.cpp
@@ -278,8 +278,19 @@ void DXContainerWriter::writeParts(raw_ostream &OS) {
dxbc::RTS0::v1::RootParameterHeader Header{L.Header.Type, L.Header.Visibility,
L.Header.Offset};
- switch (L.Header.Type) {
- case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): {
+ if (!dxbc::isValidParameterType(L.Header.Type)) {
+ // Handling invalid parameter type edge case. We intentionally let
+ // obj2yaml/yaml2obj parse and emit invalid dxcontainer data, in order
+ // for that to be used as a testing tool more effectively.
+ RS.ParametersContainer.addInvalidParameter(Header);
+ continue;
+ }
+
+ dxbc::RootParameterType ParameterType =
+ static_cast<dxbc::RootParameterType>(L.Header.Type);
+
+ switch (ParameterType) {
+ case dxbc::RootParameterType::Constants32Bit: {
const DXContainerYAML::RootConstantsYaml &ConstantYaml =
P.RootSignature->Parameters.getOrInsertConstants(L);
dxbc::RTS0::v1::RootConstants Constants;
@@ -289,9 +300,9 @@ void DXContainerWriter::writeParts(raw_ostream &OS) {
RS.ParametersContainer.addParameter(Header, Constants);
break;
}
- case llvm::to_underlying(dxbc::RootParameterType::CBV):
- case llvm::to_underlying(dxbc::RootParameterType::SRV):
- case llvm::to_underlying(dxbc::RootParameterType::UAV): {
+ case dxbc::RootParameterType::CBV:
+ case dxbc::RootParameterType::SRV:
+ case dxbc::RootParameterType::UAV: {
const DXContainerYAML::RootDescriptorYaml &DescriptorYaml =
P.RootSignature->Parameters.getOrInsertDescriptor(L);
@@ -303,7 +314,7 @@ void DXContainerWriter::writeParts(raw_ostream &OS) {
RS.ParametersContainer.addParameter(Header, Descriptor);
break;
}
- case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
+ case dxbc::RootParameterType::DescriptorTable: {
const DXContainerYAML::DescriptorTableYaml &TableYaml =
P.RootSignature->Parameters.getOrInsertTable(L);
mcdxbc::DescriptorTable Table;
@@ -323,11 +334,6 @@ void DXContainerWriter::writeParts(raw_ostream &OS) {
RS.ParametersContainer.addParameter(Header, Table);
break;
}
- default:
- // Handling invalid parameter type edge case. We intentionally let
- // obj2yaml/yaml2obj parse and emit invalid dxcontainer data, in order
- // for that to be used as a testing tool more effectively.
- RS.ParametersContainer.addInvalidParameter(Header);
}
}
diff --git a/llvm/lib/ObjectYAML/DXContainerYAML.cpp b/llvm/lib/ObjectYAML/DXContainerYAML.cpp
index 263f7bdf37bca..03d5725c21491 100644
--- a/llvm/lib/ObjectYAML/DXContainerYAML.cpp
+++ b/llvm/lib/ObjectYAML/DXContainerYAML.cpp
@@ -424,22 +424,28 @@ void MappingContextTraits<DXContainerYAML::RootParameterLocationYaml,
IO.mapRequired("ParameterType", L.Header.Type);
IO.mapRequired("ShaderVisibility", L.Header.Visibility);
- switch (L.Header.Type) {
- case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): {
+ if (!dxbc::isValidParameterType(L.Header.Type))
+ return;
+ dxbc::RootParameterType PT =
+ static_cast<dxbc::RootParameterType>(L.Header.Type);
+
+ // We allow ParameterType to be invalid here.
+ switch (PT) {
+ case dxbc::RootParameterType::Constants32Bit: {
DXContainerYAML::RootConstantsYaml &Constants =
S.Parameters.getOrInsertConstants(L);
IO.mapRequired("Constants", Constants);
break;
}
- case llvm::to_underlying(dxbc::RootParameterType::CBV):
- case llvm::to_underlying(dxbc::RootParameterType::SRV):
- case llvm::to_underlying(dxbc::RootParameterType::UAV): {
+ case dxbc::RootParameterType::CBV:
+ case dxbc::RootParameterType::SRV:
+ case dxbc::RootParameterType::UAV: {
DXContainerYAML::RootDescriptorYaml &Descriptor =
S.Parameters.getOrInsertDescriptor(L);
IO.mapRequired("Descriptor", Descriptor);
break;
}
- case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
+ case dxbc::RootParameterType::DescriptorTable: {
DXContainerYAML::DescriptorTableYaml &Table =
S.Parameters.getOrInsertTable(L);
IO.mapRequired("Table", Table);
diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.cpp b/llvm/lib/Target/DirectX/DXILRootSignature.cpp
index ebdfcaa566b51..04c7c77953a5b 100644
--- a/llvm/lib/Target/DirectX/DXILRootSignature.cpp
+++ b/llvm/lib/Target/DirectX/DXILRootSignature.cpp
@@ -175,8 +175,10 @@ PreservedAnalyses RootSignatureAnalysisPrinter::run(Module &M,
OS << "- Parameter Type: " << Type << "\n"
<< " Shader Visibility: " << Header.ShaderVisibility << "\n";
- switch (Type) {
- case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): {
+ assert(dxbc::isValidParameterType(Type) && "Invalid Parameter Type");
+ dxbc::RootParameterType PT = static_cast<dxbc::RootParameterType>(Type);
+ switch (PT) {
+ case dxbc::RootParameterType::Constants32Bit: {
const dxbc::RTS0::v1::RootConstants &Constants =
RS.ParametersContainer.getConstant(Loc);
OS << " Register Space: " << Constants.RegisterSpace << "\n"
@@ -184,9 +186,9 @@ PreservedAnalyses RootSignatureAnalysisPrinter::run(Module &M,
<< " Num 32 Bit Values: " << Constants.Num32BitValues << "\n";
break;
}
- case llvm::to_underlying(dxbc::RootParameterType::CBV):
- case llvm::to_underlying(dxbc::RootParameterType::UAV):
- case llvm::to_underlying(dxbc::RootParameterType::SRV): {
+ case dxbc::RootParameterType::CBV:
+ case dxbc::RootParameterType::UAV:
+ case dxbc::RootParameterType::SRV: {
const dxbc::RTS0::v2::RootDescriptor &Descriptor =
RS.ParametersContainer.getRootDescriptor(Loc);
OS << " Register Space: " << Descriptor.RegisterSpace << "\n"
@@ -195,7 +197,7 @@ PreservedAnalyses RootSignatureAnalysisPrinter::run(Module &M,
OS << " Flags: " << Descriptor.Flags << "\n";
break;
}
- case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
+ case dxbc::RootParameterType::DescriptorTable: {
const mcdxbc::DescriptorTable &Table =
RS.ParametersContainer.getDescriptorTable(Loc);
OS << " NumRanges: " << Table.Ranges.size() << "\n";
>From 3c1fc513424cf45e5ab6f3d2fa402036390b5e74 Mon Sep 17 00:00:00 2001
From: Joao Saffran <{ID}+{username}@users.noreply.github.com>
Date: Mon, 28 Jul 2025 13:02:55 -0700
Subject: [PATCH 2/5] remove comment
---
llvm/lib/MC/DXContainerRootSignature.cpp | 1 -
1 file changed, 1 deletion(-)
diff --git a/llvm/lib/MC/DXContainerRootSignature.cpp b/llvm/lib/MC/DXContainerRootSignature.cpp
index 454a0b88178f5..c94a39f80eeb2 100644
--- a/llvm/lib/MC/DXContainerRootSignature.cpp
+++ b/llvm/lib/MC/DXContainerRootSignature.cpp
@@ -36,7 +36,6 @@ size_t RootSignatureDesc::getSize() const {
StaticSamplers.size() * sizeof(dxbc::RTS0::v1::StaticSampler);
for (const RootParameterInfo &I : ParametersContainer) {
- // Invalid parameters are allowed while writing.
if (!dxbc::isValidParameterType(I.Header.ParameterType))
continue;
>From 591d12a0071e43d0f14381028fa47b1b15f99b85 Mon Sep 17 00:00:00 2001
From: Joao Saffran <{ID}+{username}@users.noreply.github.com>
Date: Mon, 28 Jul 2025 13:22:25 -0700
Subject: [PATCH 3/5] remove compiling errors
---
llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
index ed1e49c10818d..742b2c55bcce9 100644
--- a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
+++ b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
@@ -12,6 +12,7 @@
//===----------------------------------------------------------------------===//
#include "llvm/Frontend/HLSL/RootSignatureMetadata.h"
+#include "llvm/BinaryFormat/DXContainer.h"
#include "llvm/Frontend/HLSL/RootSignatureValidations.h"
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/IRBuilder.h"
@@ -563,6 +564,9 @@ bool MetadataParser::validateRootSignature(
static_cast<dxbc::RootParameterType>(Info.Header.ParameterType);
switch (PT) {
+ case dxbc::RootParameterType::Constants32Bit:
+ // ToDo: Add proper validation.
+ continue;
case dxbc::RootParameterType::CBV:
case dxbc::RootParameterType::UAV:
>From 0d55d2898854ea0d3aada59e418ba231961d382f Mon Sep 17 00:00:00 2001
From: Joao Saffran <{ID}+{username}@users.noreply.github.com>
Date: Tue, 29 Jul 2025 14:51:15 -0700
Subject: [PATCH 4/5] improve comment, maybe
---
llvm/lib/ObjectYAML/DXContainerYAML.cpp | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/llvm/lib/ObjectYAML/DXContainerYAML.cpp b/llvm/lib/ObjectYAML/DXContainerYAML.cpp
index 03d5725c21491..aca605e099535 100644
--- a/llvm/lib/ObjectYAML/DXContainerYAML.cpp
+++ b/llvm/lib/ObjectYAML/DXContainerYAML.cpp
@@ -424,12 +424,13 @@ void MappingContextTraits<DXContainerYAML::RootParameterLocationYaml,
IO.mapRequired("ParameterType", L.Header.Type);
IO.mapRequired("ShaderVisibility", L.Header.Visibility);
+ // If a parameter type is invalid, we don't have a body to parse.
if (!dxbc::isValidParameterType(L.Header.Type))
return;
+
+ // parses the body of a given root parameter type
dxbc::RootParameterType PT =
static_cast<dxbc::RootParameterType>(L.Header.Type);
-
- // We allow ParameterType to be invalid here.
switch (PT) {
case dxbc::RootParameterType::Constants32Bit: {
DXContainerYAML::RootConstantsYaml &Constants =
>From 25ee6d7c5b92c3d61aad261ff68974a244f8687b Mon Sep 17 00:00:00 2001
From: Joao Saffran <{ID}+{username}@users.noreply.github.com>
Date: Wed, 6 Aug 2025 11:03:14 -0700
Subject: [PATCH 5/5] fixing merge issues
---
llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp | 11 +++++++----
1 file changed, 7 insertions(+), 4 deletions(-)
diff --git a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
index 4ab70aaf20985..2b50493c21c7f 100644
--- a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
+++ b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
@@ -555,10 +555,13 @@ Error MetadataParser::validateRootSignature(
dxbc::RootParameterType PT =
static_cast<dxbc::RootParameterType>(Info.Header.ParameterType);
-
- case to_underlying(dxbc::RootParameterType::CBV):
- case to_underlying(dxbc::RootParameterType::UAV):
- case to_underlying(dxbc::RootParameterType::SRV): {
+ switch (PT) {
+ case dxbc::RootParameterType::Constants32Bit:
+ // ToDo: Add proper validation.
+ continue;
+ case dxbc::RootParameterType::CBV:
+ case dxbc::RootParameterType::UAV:
+ case dxbc::RootParameterType::SRV: {
const dxbc::RTS0::v2::RootDescriptor &Descriptor =
RSD.ParametersContainer.getRootDescriptor(Info.Location);
if (!hlsl::rootsig::verifyRegisterValue(Descriptor.ShaderRegister))
More information about the llvm-commits
mailing list