[llvm] [DirectX][NFC] Refactoring DirectX backend to not use `llvm::to_underlying` in switch cases. (PR #151032)

via llvm-commits llvm-commits at lists.llvm.org
Wed Aug 6 11:03:51 PDT 2025


https://github.com/joaosaffran updated https://github.com/llvm/llvm-project/pull/151032

>From ea35c412bbfbefb50078bdd21b44156e1e1efe48 Mon Sep 17 00:00:00 2001
From: Joao Saffran <{ID}+{username}@users.noreply.github.com>
Date: Mon, 28 Jul 2025 13:02:34 -0700
Subject: [PATCH 1/5] remove the to_underlying switch

---
 .../Frontend/HLSL/RootSignatureMetadata.cpp   | 13 ++++---
 .../HLSL/RootSignatureValidations.cpp         |  7 ++--
 llvm/lib/MC/DXContainerRootSignature.cpp      | 36 ++++++++++++-------
 llvm/lib/ObjectYAML/DXContainerEmitter.cpp    | 28 +++++++++------
 llvm/lib/ObjectYAML/DXContainerYAML.cpp       | 18 ++++++----
 llvm/lib/Target/DirectX/DXILRootSignature.cpp | 14 ++++----
 6 files changed, 72 insertions(+), 44 deletions(-)

diff --git a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
index 53f59349ae029..ed1e49c10818d 100644
--- a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
+++ b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
@@ -559,11 +559,14 @@ bool MetadataParser::validateRootSignature(
     assert(dxbc::isValidParameterType(Info.Header.ParameterType) &&
            "Invalid value for ParameterType");
 
-    switch (Info.Header.ParameterType) {
+    dxbc::RootParameterType PT =
+        static_cast<dxbc::RootParameterType>(Info.Header.ParameterType);
 
-    case llvm::to_underlying(dxbc::RootParameterType::CBV):
-    case llvm::to_underlying(dxbc::RootParameterType::UAV):
-    case llvm::to_underlying(dxbc::RootParameterType::SRV): {
+    switch (PT) {
+
+    case dxbc::RootParameterType::CBV:
+    case dxbc::RootParameterType::UAV:
+    case dxbc::RootParameterType::SRV: {
       const dxbc::RTS0::v2::RootDescriptor &Descriptor =
           RSD.ParametersContainer.getRootDescriptor(Info.Location);
       if (!llvm::hlsl::rootsig::verifyRegisterValue(Descriptor.ShaderRegister))
@@ -580,7 +583,7 @@ bool MetadataParser::validateRootSignature(
       }
       break;
     }
-    case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
+    case dxbc::RootParameterType::DescriptorTable: {
       const mcdxbc::DescriptorTable &Table =
           RSD.ParametersContainer.getDescriptorTable(Info.Location);
       for (const dxbc::RTS0::v2::DescriptorRange &Range : Table) {
diff --git a/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp b/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp
index f11c7d2033bfb..e55d3ec15c19d 100644
--- a/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp
+++ b/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp
@@ -53,10 +53,9 @@ bool verifyRootDescriptorFlag(uint32_t Version, uint32_t FlagsVal) {
 
 bool verifyRangeType(uint32_t Type) {
   switch (Type) {
-  case llvm::to_underlying(dxbc::DescriptorRangeType::CBV):
-  case llvm::to_underlying(dxbc::DescriptorRangeType::SRV):
-  case llvm::to_underlying(dxbc::DescriptorRangeType::UAV):
-  case llvm::to_underlying(dxbc::DescriptorRangeType::Sampler):
+#define DESCRIPTOR_RANGE(Num, Val)                                             \
+  case llvm::to_underlying(dxbc::DescriptorRangeType::Val):
+#include "llvm/BinaryFormat/DXContainerConstants.def"
     return true;
   };
 
diff --git a/llvm/lib/MC/DXContainerRootSignature.cpp b/llvm/lib/MC/DXContainerRootSignature.cpp
index 482280b5ef289..454a0b88178f5 100644
--- a/llvm/lib/MC/DXContainerRootSignature.cpp
+++ b/llvm/lib/MC/DXContainerRootSignature.cpp
@@ -8,6 +8,7 @@
 
 #include "llvm/MC/DXContainerRootSignature.h"
 #include "llvm/ADT/SmallString.h"
+#include "llvm/BinaryFormat/DXContainer.h"
 #include "llvm/Support/EndianStream.h"
 
 using namespace llvm;
@@ -35,20 +36,27 @@ size_t RootSignatureDesc::getSize() const {
       StaticSamplers.size() * sizeof(dxbc::RTS0::v1::StaticSampler);
 
   for (const RootParameterInfo &I : ParametersContainer) {
-    switch (I.Header.ParameterType) {
-    case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit):
+    // Invalid parameters are allowed while writing.
+    if (!dxbc::isValidParameterType(I.Header.ParameterType))
+      continue;
+
+    dxbc::RootParameterType PT =
+        static_cast<dxbc::RootParameterType>(I.Header.ParameterType);
+
+    switch (PT) {
+    case dxbc::RootParameterType::Constants32Bit:
       Size += sizeof(dxbc::RTS0::v1::RootConstants);
       break;
-    case llvm::to_underlying(dxbc::RootParameterType::CBV):
-    case llvm::to_underlying(dxbc::RootParameterType::SRV):
-    case llvm::to_underlying(dxbc::RootParameterType::UAV):
+    case dxbc::RootParameterType::CBV:
+    case dxbc::RootParameterType::SRV:
+    case dxbc::RootParameterType::UAV:
       if (Version == 1)
         Size += sizeof(dxbc::RTS0::v1::RootDescriptor);
       else
         Size += sizeof(dxbc::RTS0::v2::RootDescriptor);
 
       break;
-    case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable):
+    case dxbc::RootParameterType::DescriptorTable:
       const DescriptorTable &Table =
           ParametersContainer.getDescriptorTable(I.Location);
 
@@ -97,8 +105,12 @@ void RootSignatureDesc::write(raw_ostream &OS) const {
   for (size_t I = 0; I < NumParameters; ++I) {
     rewriteOffsetToCurrentByte(BOS, ParamsOffsets[I]);
     const auto &[Type, Loc] = ParametersContainer.getTypeAndLocForParameter(I);
-    switch (Type) {
-    case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): {
+    if (!dxbc::isValidParameterType(Type))
+      continue;
+    dxbc::RootParameterType PT = static_cast<dxbc::RootParameterType>(Type);
+
+    switch (PT) {
+    case dxbc::RootParameterType::Constants32Bit: {
       const dxbc::RTS0::v1::RootConstants &Constants =
           ParametersContainer.getConstant(Loc);
       support::endian::write(BOS, Constants.ShaderRegister,
@@ -109,9 +121,9 @@ void RootSignatureDesc::write(raw_ostream &OS) const {
                              llvm::endianness::little);
       break;
     }
-    case llvm::to_underlying(dxbc::RootParameterType::CBV):
-    case llvm::to_underlying(dxbc::RootParameterType::SRV):
-    case llvm::to_underlying(dxbc::RootParameterType::UAV): {
+    case dxbc::RootParameterType::CBV:
+    case dxbc::RootParameterType::SRV:
+    case dxbc::RootParameterType::UAV: {
       const dxbc::RTS0::v2::RootDescriptor &Descriptor =
           ParametersContainer.getRootDescriptor(Loc);
 
@@ -123,7 +135,7 @@ void RootSignatureDesc::write(raw_ostream &OS) const {
         support::endian::write(BOS, Descriptor.Flags, llvm::endianness::little);
       break;
     }
-    case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
+    case dxbc::RootParameterType::DescriptorTable: {
       const DescriptorTable &Table =
           ParametersContainer.getDescriptorTable(Loc);
       support::endian::write(BOS, (uint32_t)Table.Ranges.size(),
diff --git a/llvm/lib/ObjectYAML/DXContainerEmitter.cpp b/llvm/lib/ObjectYAML/DXContainerEmitter.cpp
index 043b575a43b11..35e1c3e3b5953 100644
--- a/llvm/lib/ObjectYAML/DXContainerEmitter.cpp
+++ b/llvm/lib/ObjectYAML/DXContainerEmitter.cpp
@@ -278,8 +278,19 @@ void DXContainerWriter::writeParts(raw_ostream &OS) {
         dxbc::RTS0::v1::RootParameterHeader Header{L.Header.Type, L.Header.Visibility,
                                          L.Header.Offset};
 
-        switch (L.Header.Type) {
-        case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): {
+        if (!dxbc::isValidParameterType(L.Header.Type)) {
+          // Handling invalid parameter type edge case. We intentionally let
+          // obj2yaml/yaml2obj parse and emit invalid dxcontainer data, in order
+          // for that to be used as a testing tool more effectively.
+          RS.ParametersContainer.addInvalidParameter(Header);
+          continue;
+        }
+
+        dxbc::RootParameterType ParameterType =
+            static_cast<dxbc::RootParameterType>(L.Header.Type);
+
+        switch (ParameterType) {
+        case dxbc::RootParameterType::Constants32Bit: {
           const DXContainerYAML::RootConstantsYaml &ConstantYaml =
               P.RootSignature->Parameters.getOrInsertConstants(L);
           dxbc::RTS0::v1::RootConstants Constants;
@@ -289,9 +300,9 @@ void DXContainerWriter::writeParts(raw_ostream &OS) {
           RS.ParametersContainer.addParameter(Header, Constants);
           break;
         }
-        case llvm::to_underlying(dxbc::RootParameterType::CBV):
-        case llvm::to_underlying(dxbc::RootParameterType::SRV):
-        case llvm::to_underlying(dxbc::RootParameterType::UAV): {
+        case dxbc::RootParameterType::CBV:
+        case dxbc::RootParameterType::SRV:
+        case dxbc::RootParameterType::UAV: {
           const DXContainerYAML::RootDescriptorYaml &DescriptorYaml =
               P.RootSignature->Parameters.getOrInsertDescriptor(L);
 
@@ -303,7 +314,7 @@ void DXContainerWriter::writeParts(raw_ostream &OS) {
           RS.ParametersContainer.addParameter(Header, Descriptor);
           break;
         }
-        case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
+        case dxbc::RootParameterType::DescriptorTable: {
           const DXContainerYAML::DescriptorTableYaml &TableYaml =
               P.RootSignature->Parameters.getOrInsertTable(L);
           mcdxbc::DescriptorTable Table;
@@ -323,11 +334,6 @@ void DXContainerWriter::writeParts(raw_ostream &OS) {
           RS.ParametersContainer.addParameter(Header, Table);
           break;
         }
-        default:
-          // Handling invalid parameter type edge case. We intentionally let
-          // obj2yaml/yaml2obj parse and emit invalid dxcontainer data, in order
-          // for that to be used as a testing tool more effectively.
-          RS.ParametersContainer.addInvalidParameter(Header);
         }
       }
 
diff --git a/llvm/lib/ObjectYAML/DXContainerYAML.cpp b/llvm/lib/ObjectYAML/DXContainerYAML.cpp
index 263f7bdf37bca..03d5725c21491 100644
--- a/llvm/lib/ObjectYAML/DXContainerYAML.cpp
+++ b/llvm/lib/ObjectYAML/DXContainerYAML.cpp
@@ -424,22 +424,28 @@ void MappingContextTraits<DXContainerYAML::RootParameterLocationYaml,
   IO.mapRequired("ParameterType", L.Header.Type);
   IO.mapRequired("ShaderVisibility", L.Header.Visibility);
 
-  switch (L.Header.Type) {
-  case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): {
+  if (!dxbc::isValidParameterType(L.Header.Type))
+    return;
+  dxbc::RootParameterType PT =
+      static_cast<dxbc::RootParameterType>(L.Header.Type);
+
+  // We allow ParameterType to be invalid here.
+  switch (PT) {
+  case dxbc::RootParameterType::Constants32Bit: {
     DXContainerYAML::RootConstantsYaml &Constants =
         S.Parameters.getOrInsertConstants(L);
     IO.mapRequired("Constants", Constants);
     break;
   }
-  case llvm::to_underlying(dxbc::RootParameterType::CBV):
-  case llvm::to_underlying(dxbc::RootParameterType::SRV):
-  case llvm::to_underlying(dxbc::RootParameterType::UAV): {
+  case dxbc::RootParameterType::CBV:
+  case dxbc::RootParameterType::SRV:
+  case dxbc::RootParameterType::UAV: {
     DXContainerYAML::RootDescriptorYaml &Descriptor =
         S.Parameters.getOrInsertDescriptor(L);
     IO.mapRequired("Descriptor", Descriptor);
     break;
   }
-  case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
+  case dxbc::RootParameterType::DescriptorTable: {
     DXContainerYAML::DescriptorTableYaml &Table =
         S.Parameters.getOrInsertTable(L);
     IO.mapRequired("Table", Table);
diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.cpp b/llvm/lib/Target/DirectX/DXILRootSignature.cpp
index ebdfcaa566b51..04c7c77953a5b 100644
--- a/llvm/lib/Target/DirectX/DXILRootSignature.cpp
+++ b/llvm/lib/Target/DirectX/DXILRootSignature.cpp
@@ -175,8 +175,10 @@ PreservedAnalyses RootSignatureAnalysisPrinter::run(Module &M,
       OS << "- Parameter Type: " << Type << "\n"
          << "  Shader Visibility: " << Header.ShaderVisibility << "\n";
 
-      switch (Type) {
-      case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): {
+      assert(dxbc::isValidParameterType(Type) && "Invalid Parameter Type");
+      dxbc::RootParameterType PT = static_cast<dxbc::RootParameterType>(Type);
+      switch (PT) {
+      case dxbc::RootParameterType::Constants32Bit: {
         const dxbc::RTS0::v1::RootConstants &Constants =
             RS.ParametersContainer.getConstant(Loc);
         OS << "  Register Space: " << Constants.RegisterSpace << "\n"
@@ -184,9 +186,9 @@ PreservedAnalyses RootSignatureAnalysisPrinter::run(Module &M,
            << "  Num 32 Bit Values: " << Constants.Num32BitValues << "\n";
         break;
       }
-      case llvm::to_underlying(dxbc::RootParameterType::CBV):
-      case llvm::to_underlying(dxbc::RootParameterType::UAV):
-      case llvm::to_underlying(dxbc::RootParameterType::SRV): {
+      case dxbc::RootParameterType::CBV:
+      case dxbc::RootParameterType::UAV:
+      case dxbc::RootParameterType::SRV: {
         const dxbc::RTS0::v2::RootDescriptor &Descriptor =
             RS.ParametersContainer.getRootDescriptor(Loc);
         OS << "  Register Space: " << Descriptor.RegisterSpace << "\n"
@@ -195,7 +197,7 @@ PreservedAnalyses RootSignatureAnalysisPrinter::run(Module &M,
           OS << "  Flags: " << Descriptor.Flags << "\n";
         break;
       }
-      case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
+      case dxbc::RootParameterType::DescriptorTable: {
         const mcdxbc::DescriptorTable &Table =
             RS.ParametersContainer.getDescriptorTable(Loc);
         OS << "  NumRanges: " << Table.Ranges.size() << "\n";

>From 3c1fc513424cf45e5ab6f3d2fa402036390b5e74 Mon Sep 17 00:00:00 2001
From: Joao Saffran <{ID}+{username}@users.noreply.github.com>
Date: Mon, 28 Jul 2025 13:02:55 -0700
Subject: [PATCH 2/5] remove comment

---
 llvm/lib/MC/DXContainerRootSignature.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/llvm/lib/MC/DXContainerRootSignature.cpp b/llvm/lib/MC/DXContainerRootSignature.cpp
index 454a0b88178f5..c94a39f80eeb2 100644
--- a/llvm/lib/MC/DXContainerRootSignature.cpp
+++ b/llvm/lib/MC/DXContainerRootSignature.cpp
@@ -36,7 +36,6 @@ size_t RootSignatureDesc::getSize() const {
       StaticSamplers.size() * sizeof(dxbc::RTS0::v1::StaticSampler);
 
   for (const RootParameterInfo &I : ParametersContainer) {
-    // Invalid parameters are allowed while writing.
     if (!dxbc::isValidParameterType(I.Header.ParameterType))
       continue;
 

>From 591d12a0071e43d0f14381028fa47b1b15f99b85 Mon Sep 17 00:00:00 2001
From: Joao Saffran <{ID}+{username}@users.noreply.github.com>
Date: Mon, 28 Jul 2025 13:22:25 -0700
Subject: [PATCH 3/5] remove compiling errors

---
 llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
index ed1e49c10818d..742b2c55bcce9 100644
--- a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
+++ b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Frontend/HLSL/RootSignatureMetadata.h"
+#include "llvm/BinaryFormat/DXContainer.h"
 #include "llvm/Frontend/HLSL/RootSignatureValidations.h"
 #include "llvm/IR/DiagnosticInfo.h"
 #include "llvm/IR/IRBuilder.h"
@@ -563,6 +564,9 @@ bool MetadataParser::validateRootSignature(
         static_cast<dxbc::RootParameterType>(Info.Header.ParameterType);
 
     switch (PT) {
+    case dxbc::RootParameterType::Constants32Bit:
+      // ToDo: Add proper validation.
+      continue;
 
     case dxbc::RootParameterType::CBV:
     case dxbc::RootParameterType::UAV:

>From 0d55d2898854ea0d3aada59e418ba231961d382f Mon Sep 17 00:00:00 2001
From: Joao Saffran <{ID}+{username}@users.noreply.github.com>
Date: Tue, 29 Jul 2025 14:51:15 -0700
Subject: [PATCH 4/5] improve comment, maybe

---
 llvm/lib/ObjectYAML/DXContainerYAML.cpp | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/ObjectYAML/DXContainerYAML.cpp b/llvm/lib/ObjectYAML/DXContainerYAML.cpp
index 03d5725c21491..aca605e099535 100644
--- a/llvm/lib/ObjectYAML/DXContainerYAML.cpp
+++ b/llvm/lib/ObjectYAML/DXContainerYAML.cpp
@@ -424,12 +424,13 @@ void MappingContextTraits<DXContainerYAML::RootParameterLocationYaml,
   IO.mapRequired("ParameterType", L.Header.Type);
   IO.mapRequired("ShaderVisibility", L.Header.Visibility);
 
+  // If a parameter type is invalid, we don't have a body to parse.
   if (!dxbc::isValidParameterType(L.Header.Type))
     return;
+
+  // parses the body of a given root parameter type
   dxbc::RootParameterType PT =
       static_cast<dxbc::RootParameterType>(L.Header.Type);
-
-  // We allow ParameterType to be invalid here.
   switch (PT) {
   case dxbc::RootParameterType::Constants32Bit: {
     DXContainerYAML::RootConstantsYaml &Constants =

>From 25ee6d7c5b92c3d61aad261ff68974a244f8687b Mon Sep 17 00:00:00 2001
From: Joao Saffran <{ID}+{username}@users.noreply.github.com>
Date: Wed, 6 Aug 2025 11:03:14 -0700
Subject: [PATCH 5/5] fixing merge issues

---
 llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
index 4ab70aaf20985..2b50493c21c7f 100644
--- a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
+++ b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
@@ -555,10 +555,13 @@ Error MetadataParser::validateRootSignature(
 
     dxbc::RootParameterType PT =
         static_cast<dxbc::RootParameterType>(Info.Header.ParameterType);
-
-    case to_underlying(dxbc::RootParameterType::CBV):
-    case to_underlying(dxbc::RootParameterType::UAV):
-    case to_underlying(dxbc::RootParameterType::SRV): {
+    switch (PT) {
+    case dxbc::RootParameterType::Constants32Bit:
+      // ToDo: Add proper validation.
+      continue;
+    case dxbc::RootParameterType::CBV:
+    case dxbc::RootParameterType::UAV:
+    case dxbc::RootParameterType::SRV: {
       const dxbc::RTS0::v2::RootDescriptor &Descriptor =
           RSD.ParametersContainer.getRootDescriptor(Info.Location);
       if (!hlsl::rootsig::verifyRegisterValue(Descriptor.ShaderRegister))



More information about the llvm-commits mailing list