[llvm] c6dfbc5 - [DirectX] Refactor RootSignature Backend to remove `to_underlying` from Root Parameter Header (#154249)

via llvm-commits llvm-commits at lists.llvm.org
Mon Aug 25 13:28:11 PDT 2025


Author: joaosaffran
Date: 2025-08-25T16:28:07-04:00
New Revision: c6dfbc5cc7b89637ec7f06d7c0018ef8964c9323

URL: https://github.com/llvm/llvm-project/commit/c6dfbc5cc7b89637ec7f06d7c0018ef8964c9323
DIFF: https://github.com/llvm/llvm-project/commit/c6dfbc5cc7b89637ec7f06d7c0018ef8964c9323.diff

LOG: [DirectX] Refactor RootSignature Backend to remove `to_underlying` from Root Parameter Header (#154249)

This patch refactors the Root Parameter Header handling in the DX Container
backend to remove the use of `to_underlying`. This requires two changes:
first, MC Root Signature must no longer depend on Object/DXContainer.h;
second, data must now be assumed valid in places that previously tolerated
invalid values (for example, yaml2obj now asserts on invalid parameter types
and shader visibilities), so the tests that exercised those invalid values
were removed.

Added: 
    

Modified: 
    llvm/include/llvm/MC/DXContainerRootSignature.h
    llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
    llvm/lib/MC/DXContainerRootSignature.cpp
    llvm/lib/ObjectYAML/DXContainerEmitter.cpp
    llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp
    llvm/lib/Target/DirectX/DXILRootSignature.cpp
    llvm/test/CodeGen/DirectX/ContainerData/RootSignature-Parameters.ll

Removed: 
    llvm/test/ObjectYAML/DXContainer/RootSignature-InvalidType.yaml
    llvm/test/ObjectYAML/DXContainer/RootSignature-InvalidVisibility.yaml


################################################################################
diff  --git a/llvm/include/llvm/MC/DXContainerRootSignature.h b/llvm/include/llvm/MC/DXContainerRootSignature.h
index 3c7c886e79fc3..4db3f3458c808 100644
--- a/llvm/include/llvm/MC/DXContainerRootSignature.h
+++ b/llvm/include/llvm/MC/DXContainerRootSignature.h
@@ -20,13 +20,13 @@ class raw_ostream;
 namespace mcdxbc {
 
 struct RootParameterInfo {
-  dxbc::RTS0::v1::RootParameterHeader Header;
+  dxbc::RootParameterType Type;
+  dxbc::ShaderVisibility Visibility;
   size_t Location;
 
-  RootParameterInfo() = default;
-
-  RootParameterInfo(dxbc::RTS0::v1::RootParameterHeader Header, size_t Location)
-      : Header(Header), Location(Location) {}
+  RootParameterInfo(dxbc::RootParameterType Type,
+                    dxbc::ShaderVisibility Visibility, size_t Location)
+      : Type(Type), Visibility(Visibility), Location(Location) {}
 };
 
 struct DescriptorTable {
@@ -46,41 +46,34 @@ struct RootParametersContainer {
   SmallVector<dxbc::RTS0::v2::RootDescriptor> Descriptors;
   SmallVector<DescriptorTable> Tables;
 
-  void addInfo(dxbc::RTS0::v1::RootParameterHeader Header, size_t Location) {
-    ParametersInfo.push_back(RootParameterInfo(Header, Location));
+  void addInfo(dxbc::RootParameterType Type, dxbc::ShaderVisibility Visibility,
+               size_t Location) {
+    ParametersInfo.emplace_back(Type, Visibility, Location);
   }
 
-  void addParameter(dxbc::RTS0::v1::RootParameterHeader Header,
+  void addParameter(dxbc::RootParameterType Type,
+                    dxbc::ShaderVisibility Visibility,
                     dxbc::RTS0::v1::RootConstants Constant) {
-    addInfo(Header, Constants.size());
+    addInfo(Type, Visibility, Constants.size());
     Constants.push_back(Constant);
   }
 
-  void addInvalidParameter(dxbc::RTS0::v1::RootParameterHeader Header) {
-    addInfo(Header, -1);
-  }
-
-  void addParameter(dxbc::RTS0::v1::RootParameterHeader Header,
+  void addParameter(dxbc::RootParameterType Type,
+                    dxbc::ShaderVisibility Visibility,
                     dxbc::RTS0::v2::RootDescriptor Descriptor) {
-    addInfo(Header, Descriptors.size());
+    addInfo(Type, Visibility, Descriptors.size());
     Descriptors.push_back(Descriptor);
   }
 
-  void addParameter(dxbc::RTS0::v1::RootParameterHeader Header,
-                    DescriptorTable Table) {
-    addInfo(Header, Tables.size());
+  void addParameter(dxbc::RootParameterType Type,
+                    dxbc::ShaderVisibility Visibility, DescriptorTable Table) {
+    addInfo(Type, Visibility, Tables.size());
     Tables.push_back(Table);
   }
 
-  std::pair<uint32_t, uint32_t>
-  getTypeAndLocForParameter(uint32_t Location) const {
-    const RootParameterInfo &Info = ParametersInfo[Location];
-    return {Info.Header.ParameterType, Info.Location};
-  }
-
-  const dxbc::RTS0::v1::RootParameterHeader &getHeader(size_t Location) const {
+  const RootParameterInfo &getInfo(uint32_t Location) const {
     const RootParameterInfo &Info = ParametersInfo[Location];
-    return Info.Header;
+    return Info;
   }
 
   const dxbc::RTS0::v1::RootConstants &getConstant(size_t Index) const {

diff  --git a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
index dece8f197aaf7..70f2646d66c57 100644
--- a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
+++ b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
@@ -52,6 +52,17 @@ static std::optional<StringRef> extractMdStringValue(MDNode *Node,
   return NodeText->getString();
 }
 
+static Expected<dxbc::ShaderVisibility>
+extractShaderVisibility(MDNode *Node, unsigned int OpId) {
+  if (std::optional<uint32_t> Val = extractMdIntValue(Node, OpId)) {
+    if (!dxbc::isValidShaderVisibility(*Val))
+      return make_error<RootSignatureValidationError<uint32_t>>(
+          "ShaderVisibility", *Val);
+    return dxbc::ShaderVisibility(*Val);
+  }
+  return make_error<InvalidRSMetadataValue>("ShaderVisibility");
+}
+
 namespace {
 
 // We use the OverloadVisit with std::visit to ensure the compiler catches if a
@@ -221,15 +232,10 @@ Error MetadataParser::parseRootConstants(mcdxbc::RootSignatureDesc &RSD,
   if (RootConstantNode->getNumOperands() != 5)
     return make_error<InvalidRSMetadataFormat>("RootConstants Element");
 
-  dxbc::RTS0::v1::RootParameterHeader Header;
-  // The parameter offset doesn't matter here - we recalculate it during
-  // serialization  Header.ParameterOffset = 0;
-  Header.ParameterType = to_underlying(dxbc::RootParameterType::Constants32Bit);
-
-  if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 1))
-    Header.ShaderVisibility = *Val;
-  else
-    return make_error<InvalidRSMetadataValue>("ShaderVisibility");
+  Expected<dxbc::ShaderVisibility> Visibility =
+      extractShaderVisibility(RootConstantNode, 1);
+  if (auto E = Visibility.takeError())
+    return Error(std::move(E));
 
   dxbc::RTS0::v1::RootConstants Constants;
   if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 2))
@@ -247,7 +253,8 @@ Error MetadataParser::parseRootConstants(mcdxbc::RootSignatureDesc &RSD,
   else
     return make_error<InvalidRSMetadataValue>("Num32BitValues");
 
-  RSD.ParametersContainer.addParameter(Header, Constants);
+  RSD.ParametersContainer.addParameter(dxbc::RootParameterType::Constants32Bit,
+                                       *Visibility, Constants);
 
   return Error::success();
 }
@@ -263,26 +270,26 @@ Error MetadataParser::parseRootDescriptors(
   if (RootDescriptorNode->getNumOperands() != 5)
     return make_error<InvalidRSMetadataFormat>("Root Descriptor Element");
 
-  dxbc::RTS0::v1::RootParameterHeader Header;
+  dxbc::RootParameterType Type;
   switch (ElementKind) {
   case RootSignatureElementKind::SRV:
-    Header.ParameterType = to_underlying(dxbc::RootParameterType::SRV);
+    Type = dxbc::RootParameterType::SRV;
     break;
   case RootSignatureElementKind::UAV:
-    Header.ParameterType = to_underlying(dxbc::RootParameterType::UAV);
+    Type = dxbc::RootParameterType::UAV;
     break;
   case RootSignatureElementKind::CBV:
-    Header.ParameterType = to_underlying(dxbc::RootParameterType::CBV);
+    Type = dxbc::RootParameterType::CBV;
     break;
   default:
     llvm_unreachable("invalid Root Descriptor kind");
     break;
   }
 
-  if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 1))
-    Header.ShaderVisibility = *Val;
-  else
-    return make_error<InvalidRSMetadataValue>("ShaderVisibility");
+  Expected<dxbc::ShaderVisibility> Visibility =
+      extractShaderVisibility(RootDescriptorNode, 1);
+  if (auto E = Visibility.takeError())
+    return Error(std::move(E));
 
   dxbc::RTS0::v2::RootDescriptor Descriptor;
   if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 2))
@@ -296,7 +303,7 @@ Error MetadataParser::parseRootDescriptors(
     return make_error<InvalidRSMetadataValue>("RegisterSpace");
 
   if (RSD.Version == 1) {
-    RSD.ParametersContainer.addParameter(Header, Descriptor);
+    RSD.ParametersContainer.addParameter(Type, *Visibility, Descriptor);
     return Error::success();
   }
   assert(RSD.Version > 1);
@@ -306,7 +313,7 @@ Error MetadataParser::parseRootDescriptors(
   else
     return make_error<InvalidRSMetadataValue>("Root Descriptor Flags");
 
-  RSD.ParametersContainer.addParameter(Header, Descriptor);
+  RSD.ParametersContainer.addParameter(Type, *Visibility, Descriptor);
   return Error::success();
 }
 
@@ -372,15 +379,12 @@ Error MetadataParser::parseDescriptorTable(mcdxbc::RootSignatureDesc &RSD,
   if (NumOperands < 2)
     return make_error<InvalidRSMetadataFormat>("Descriptor Table");
 
-  dxbc::RTS0::v1::RootParameterHeader Header;
-  if (std::optional<uint32_t> Val = extractMdIntValue(DescriptorTableNode, 1))
-    Header.ShaderVisibility = *Val;
-  else
-    return make_error<InvalidRSMetadataValue>("ShaderVisibility");
+  Expected<dxbc::ShaderVisibility> Visibility =
+      extractShaderVisibility(DescriptorTableNode, 1);
+  if (auto E = Visibility.takeError())
+    return Error(std::move(E));
 
   mcdxbc::DescriptorTable Table;
-  Header.ParameterType =
-      to_underlying(dxbc::RootParameterType::DescriptorTable);
 
   for (unsigned int I = 2; I < NumOperands; I++) {
     MDNode *Element = dyn_cast<MDNode>(DescriptorTableNode->getOperand(I));
@@ -392,7 +396,8 @@ Error MetadataParser::parseDescriptorTable(mcdxbc::RootSignatureDesc &RSD,
       return Err;
   }
 
-  RSD.ParametersContainer.addParameter(Header, Table);
+  RSD.ParametersContainer.addParameter(dxbc::RootParameterType::DescriptorTable,
+                                       *Visibility, Table);
   return Error::success();
 }
 
@@ -528,20 +533,14 @@ Error MetadataParser::validateRootSignature(
   }
 
   for (const mcdxbc::RootParameterInfo &Info : RSD.ParametersContainer) {
-    if (!dxbc::isValidShaderVisibility(Info.Header.ShaderVisibility))
-      DeferredErrs =
-          joinErrors(std::move(DeferredErrs),
-                     make_error<RootSignatureValidationError<uint32_t>>(
-                         "ShaderVisibility", Info.Header.ShaderVisibility));
-
-    assert(dxbc::isValidParameterType(Info.Header.ParameterType) &&
-           "Invalid value for ParameterType");
 
-    switch (Info.Header.ParameterType) {
+    switch (Info.Type) {
+    case dxbc::RootParameterType::Constants32Bit:
+      break;
 
-    case to_underlying(dxbc::RootParameterType::CBV):
-    case to_underlying(dxbc::RootParameterType::UAV):
-    case to_underlying(dxbc::RootParameterType::SRV): {
+    case dxbc::RootParameterType::CBV:
+    case dxbc::RootParameterType::UAV:
+    case dxbc::RootParameterType::SRV: {
       const dxbc::RTS0::v2::RootDescriptor &Descriptor =
           RSD.ParametersContainer.getRootDescriptor(Info.Location);
       if (!hlsl::rootsig::verifyRegisterValue(Descriptor.ShaderRegister))
@@ -566,7 +565,7 @@ Error MetadataParser::validateRootSignature(
       }
       break;
     }
-    case to_underlying(dxbc::RootParameterType::DescriptorTable): {
+    case dxbc::RootParameterType::DescriptorTable: {
       const mcdxbc::DescriptorTable &Table =
           RSD.ParametersContainer.getDescriptorTable(Info.Location);
       for (const dxbc::RTS0::v2::DescriptorRange &Range : Table) {

diff  --git a/llvm/lib/MC/DXContainerRootSignature.cpp b/llvm/lib/MC/DXContainerRootSignature.cpp
index 482280b5ef289..c04dc6bd1800a 100644
--- a/llvm/lib/MC/DXContainerRootSignature.cpp
+++ b/llvm/lib/MC/DXContainerRootSignature.cpp
@@ -35,20 +35,20 @@ size_t RootSignatureDesc::getSize() const {
       StaticSamplers.size() * sizeof(dxbc::RTS0::v1::StaticSampler);
 
   for (const RootParameterInfo &I : ParametersContainer) {
-    switch (I.Header.ParameterType) {
-    case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit):
+    switch (I.Type) {
+    case dxbc::RootParameterType::Constants32Bit:
       Size += sizeof(dxbc::RTS0::v1::RootConstants);
       break;
-    case llvm::to_underlying(dxbc::RootParameterType::CBV):
-    case llvm::to_underlying(dxbc::RootParameterType::SRV):
-    case llvm::to_underlying(dxbc::RootParameterType::UAV):
+    case dxbc::RootParameterType::CBV:
+    case dxbc::RootParameterType::SRV:
+    case dxbc::RootParameterType::UAV:
       if (Version == 1)
         Size += sizeof(dxbc::RTS0::v1::RootDescriptor);
       else
         Size += sizeof(dxbc::RTS0::v2::RootDescriptor);
 
       break;
-    case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable):
+    case dxbc::RootParameterType::DescriptorTable:
       const DescriptorTable &Table =
           ParametersContainer.getDescriptorTable(I.Location);
 
@@ -84,11 +84,9 @@ void RootSignatureDesc::write(raw_ostream &OS) const {
   support::endian::write(BOS, Flags, llvm::endianness::little);
 
   SmallVector<uint32_t> ParamsOffsets;
-  for (const RootParameterInfo &P : ParametersContainer) {
-    support::endian::write(BOS, P.Header.ParameterType,
-                           llvm::endianness::little);
-    support::endian::write(BOS, P.Header.ShaderVisibility,
-                           llvm::endianness::little);
+  for (const RootParameterInfo &I : ParametersContainer) {
+    support::endian::write(BOS, I.Type, llvm::endianness::little);
+    support::endian::write(BOS, I.Visibility, llvm::endianness::little);
 
     ParamsOffsets.push_back(writePlaceholder(BOS));
   }
@@ -96,11 +94,11 @@ void RootSignatureDesc::write(raw_ostream &OS) const {
   assert(NumParameters == ParamsOffsets.size());
   for (size_t I = 0; I < NumParameters; ++I) {
     rewriteOffsetToCurrentByte(BOS, ParamsOffsets[I]);
-    const auto &[Type, Loc] = ParametersContainer.getTypeAndLocForParameter(I);
-    switch (Type) {
-    case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): {
+    const RootParameterInfo &Info = ParametersContainer.getInfo(I);
+    switch (Info.Type) {
+    case dxbc::RootParameterType::Constants32Bit: {
       const dxbc::RTS0::v1::RootConstants &Constants =
-          ParametersContainer.getConstant(Loc);
+          ParametersContainer.getConstant(Info.Location);
       support::endian::write(BOS, Constants.ShaderRegister,
                              llvm::endianness::little);
       support::endian::write(BOS, Constants.RegisterSpace,
@@ -109,11 +107,11 @@ void RootSignatureDesc::write(raw_ostream &OS) const {
                              llvm::endianness::little);
       break;
     }
-    case llvm::to_underlying(dxbc::RootParameterType::CBV):
-    case llvm::to_underlying(dxbc::RootParameterType::SRV):
-    case llvm::to_underlying(dxbc::RootParameterType::UAV): {
+    case dxbc::RootParameterType::CBV:
+    case dxbc::RootParameterType::SRV:
+    case dxbc::RootParameterType::UAV: {
       const dxbc::RTS0::v2::RootDescriptor &Descriptor =
-          ParametersContainer.getRootDescriptor(Loc);
+          ParametersContainer.getRootDescriptor(Info.Location);
 
       support::endian::write(BOS, Descriptor.ShaderRegister,
                              llvm::endianness::little);
@@ -123,9 +121,9 @@ void RootSignatureDesc::write(raw_ostream &OS) const {
         support::endian::write(BOS, Descriptor.Flags, llvm::endianness::little);
       break;
     }
-    case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
+    case dxbc::RootParameterType::DescriptorTable: {
       const DescriptorTable &Table =
-          ParametersContainer.getDescriptorTable(Loc);
+          ParametersContainer.getDescriptorTable(Info.Location);
       support::endian::write(BOS, (uint32_t)Table.Ranges.size(),
                              llvm::endianness::little);
       rewriteOffsetToCurrentByte(BOS, writePlaceholder(BOS));

diff  --git a/llvm/lib/ObjectYAML/DXContainerEmitter.cpp b/llvm/lib/ObjectYAML/DXContainerEmitter.cpp
index 043b575a43b11..b112c6f21ee5a 100644
--- a/llvm/lib/ObjectYAML/DXContainerEmitter.cpp
+++ b/llvm/lib/ObjectYAML/DXContainerEmitter.cpp
@@ -275,23 +275,30 @@ void DXContainerWriter::writeParts(raw_ostream &OS) {
 
       for (DXContainerYAML::RootParameterLocationYaml &L :
            P.RootSignature->Parameters.Locations) {
-        dxbc::RTS0::v1::RootParameterHeader Header{L.Header.Type, L.Header.Visibility,
-                                         L.Header.Offset};
 
-        switch (L.Header.Type) {
-        case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): {
+        assert(dxbc::isValidParameterType(L.Header.Type) &&
+               "invalid DXContainer YAML");
+        assert(dxbc::isValidShaderVisibility(L.Header.Visibility) &&
+               "invalid DXContainer YAML");
+        dxbc::RootParameterType Type = dxbc::RootParameterType(L.Header.Type);
+        dxbc::ShaderVisibility Visibility =
+            dxbc::ShaderVisibility(L.Header.Visibility);
+
+        switch (Type) {
+        case dxbc::RootParameterType::Constants32Bit: {
           const DXContainerYAML::RootConstantsYaml &ConstantYaml =
               P.RootSignature->Parameters.getOrInsertConstants(L);
           dxbc::RTS0::v1::RootConstants Constants;
+
           Constants.Num32BitValues = ConstantYaml.Num32BitValues;
           Constants.RegisterSpace = ConstantYaml.RegisterSpace;
           Constants.ShaderRegister = ConstantYaml.ShaderRegister;
-          RS.ParametersContainer.addParameter(Header, Constants);
+          RS.ParametersContainer.addParameter(Type, Visibility, Constants);
           break;
         }
-        case llvm::to_underlying(dxbc::RootParameterType::CBV):
-        case llvm::to_underlying(dxbc::RootParameterType::SRV):
-        case llvm::to_underlying(dxbc::RootParameterType::UAV): {
+        case dxbc::RootParameterType::CBV:
+        case dxbc::RootParameterType::SRV:
+        case dxbc::RootParameterType::UAV: {
           const DXContainerYAML::RootDescriptorYaml &DescriptorYaml =
               P.RootSignature->Parameters.getOrInsertDescriptor(L);
 
@@ -300,10 +307,10 @@ void DXContainerWriter::writeParts(raw_ostream &OS) {
           Descriptor.ShaderRegister = DescriptorYaml.ShaderRegister;
           if (RS.Version > 1)
             Descriptor.Flags = DescriptorYaml.getEncodedFlags();
-          RS.ParametersContainer.addParameter(Header, Descriptor);
+          RS.ParametersContainer.addParameter(Type, Visibility, Descriptor);
           break;
         }
-        case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
+        case dxbc::RootParameterType::DescriptorTable: {
           const DXContainerYAML::DescriptorTableYaml &TableYaml =
               P.RootSignature->Parameters.getOrInsertTable(L);
           mcdxbc::DescriptorTable Table;
@@ -320,14 +327,9 @@ void DXContainerWriter::writeParts(raw_ostream &OS) {
               Range.Flags = R.getEncodedFlags();
             Table.Ranges.push_back(Range);
           }
-          RS.ParametersContainer.addParameter(Header, Table);
+          RS.ParametersContainer.addParameter(Type, Visibility, Table);
           break;
         }
-        default:
-          // Handling invalid parameter type edge case. We intentionally let
-          // obj2yaml/yaml2obj parse and emit invalid dxcontainer data, in order
-          // for that to be used as a testing tool more effectively.
-          RS.ParametersContainer.addInvalidParameter(Header);
         }
       }
 

diff  --git a/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp b/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp
index be2c7d1ddff3f..fc0afb9a0efdf 100644
--- a/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp
+++ b/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp
@@ -164,12 +164,11 @@ static void validateRootSignature(Module &M,
 
   for (const mcdxbc::RootParameterInfo &ParamInfo : RSD.ParametersContainer) {
     dxbc::ShaderVisibility ParamVisibility =
-        static_cast<dxbc::ShaderVisibility>(ParamInfo.Header.ShaderVisibility);
+        dxbc::ShaderVisibility(ParamInfo.Visibility);
     if (ParamVisibility != dxbc::ShaderVisibility::All &&
         ParamVisibility != Visibility)
       continue;
-    dxbc::RootParameterType ParamType =
-        static_cast<dxbc::RootParameterType>(ParamInfo.Header.ParameterType);
+    dxbc::RootParameterType ParamType = dxbc::RootParameterType(ParamInfo.Type);
     switch (ParamType) {
     case dxbc::RootParameterType::Constants32Bit: {
       dxbc::RTS0::v1::RootConstants Const =
@@ -185,10 +184,9 @@ static void validateRootSignature(Module &M,
     case dxbc::RootParameterType::CBV: {
       dxbc::RTS0::v2::RootDescriptor Desc =
           RSD.ParametersContainer.getRootDescriptor(ParamInfo.Location);
-      Builder.trackBinding(toResourceClass(static_cast<dxbc::RootParameterType>(
-                               ParamInfo.Header.ParameterType)),
-                           Desc.RegisterSpace, Desc.ShaderRegister,
-                           Desc.ShaderRegister, &ParamInfo);
+      Builder.trackBinding(toResourceClass(ParamInfo.Type), Desc.RegisterSpace,
+                           Desc.ShaderRegister, Desc.ShaderRegister,
+                           &ParamInfo);
 
       break;
     }

diff  --git a/llvm/lib/Target/DirectX/DXILRootSignature.cpp b/llvm/lib/Target/DirectX/DXILRootSignature.cpp
index a4f5086c2f428..62037a8272e7c 100644
--- a/llvm/lib/Target/DirectX/DXILRootSignature.cpp
+++ b/llvm/lib/Target/DirectX/DXILRootSignature.cpp
@@ -24,9 +24,11 @@
 #include "llvm/IR/Metadata.h"
 #include "llvm/IR/Module.h"
 #include "llvm/InitializePasses.h"
+#include "llvm/MC/DXContainerRootSignature.h"
 #include "llvm/Pass.h"
 #include "llvm/Support/Error.h"
 #include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/ScopedPrinter.h"
 #include "llvm/Support/raw_ostream.h"
 #include <cstdint>
 
@@ -171,37 +173,36 @@ PreservedAnalyses RootSignatureAnalysisPrinter::run(Module &M,
        << "RootParametersOffset: " << RS.RootParameterOffset << "\n"
        << "NumParameters: " << RS.ParametersContainer.size() << "\n";
     for (size_t I = 0; I < RS.ParametersContainer.size(); I++) {
-      const auto &[Type, Loc] =
-          RS.ParametersContainer.getTypeAndLocForParameter(I);
-      const dxbc::RTS0::v1::RootParameterHeader Header =
-          RS.ParametersContainer.getHeader(I);
-
-      OS << "- Parameter Type: " << Type << "\n"
-         << "  Shader Visibility: " << Header.ShaderVisibility << "\n";
-
-      switch (Type) {
-      case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): {
+      const mcdxbc::RootParameterInfo &Info = RS.ParametersContainer.getInfo(I);
+
+      OS << "- Parameter Type: "
+         << enumToStringRef(Info.Type, dxbc::getRootParameterTypes()) << "\n"
+         << "  Shader Visibility: "
+         << enumToStringRef(Info.Visibility, dxbc::getShaderVisibility())
+         << "\n";
+      switch (Info.Type) {
+      case dxbc::RootParameterType::Constants32Bit: {
         const dxbc::RTS0::v1::RootConstants &Constants =
-            RS.ParametersContainer.getConstant(Loc);
+            RS.ParametersContainer.getConstant(Info.Location);
         OS << "  Register Space: " << Constants.RegisterSpace << "\n"
            << "  Shader Register: " << Constants.ShaderRegister << "\n"
            << "  Num 32 Bit Values: " << Constants.Num32BitValues << "\n";
         break;
       }
-      case llvm::to_underlying(dxbc::RootParameterType::CBV):
-      case llvm::to_underlying(dxbc::RootParameterType::UAV):
-      case llvm::to_underlying(dxbc::RootParameterType::SRV): {
+      case dxbc::RootParameterType::CBV:
+      case dxbc::RootParameterType::UAV:
+      case dxbc::RootParameterType::SRV: {
         const dxbc::RTS0::v2::RootDescriptor &Descriptor =
-            RS.ParametersContainer.getRootDescriptor(Loc);
+            RS.ParametersContainer.getRootDescriptor(Info.Location);
         OS << "  Register Space: " << Descriptor.RegisterSpace << "\n"
            << "  Shader Register: " << Descriptor.ShaderRegister << "\n";
         if (RS.Version > 1)
           OS << "  Flags: " << Descriptor.Flags << "\n";
         break;
       }
-      case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
+      case dxbc::RootParameterType::DescriptorTable: {
         const mcdxbc::DescriptorTable &Table =
-            RS.ParametersContainer.getDescriptorTable(Loc);
+            RS.ParametersContainer.getDescriptorTable(Info.Location);
         OS << "  NumRanges: " << Table.Ranges.size() << "\n";
 
         for (const dxbc::RTS0::v2::DescriptorRange Range : Table) {

diff  --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-Parameters.ll b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-Parameters.ll
index 6477ad397c32d..742fea14f5af6 100644
--- a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-Parameters.ll
+++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-Parameters.ll
@@ -25,18 +25,18 @@ attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
 ;CHECK-NEXT:  Version: 2
 ;CHECK-NEXT:  RootParametersOffset: 24
 ;CHECK-NEXT:  NumParameters: 3
-;CHECK-NEXT:   - Parameter Type: 1
-;CHECK-NEXT:     Shader Visibility: 0
+;CHECK-NEXT:   - Parameter Type: Constants32Bit
+;CHECK-NEXT:     Shader Visibility: All
 ;CHECK-NEXT:     Register Space: 2
 ;CHECK-NEXT:     Shader Register: 1
 ;CHECK-NEXT:     Num 32 Bit Values: 3
-;CHECK-NEXT:   - Parameter Type: 3
-;CHECK-NEXT:     Shader Visibility: 1
+;CHECK-NEXT:   - Parameter Type: SRV
+;CHECK-NEXT:     Shader Visibility: Vertex
 ;CHECK-NEXT:     Register Space: 5
 ;CHECK-NEXT:     Shader Register: 4
 ;CHECK-NEXT:     Flags: 4
-;CHECK-NEXT:   - Parameter Type: 0
-;CHECK-NEXT:     Shader Visibility: 0
+;CHECK-NEXT:   - Parameter Type: DescriptorTable
+;CHECK-NEXT:     Shader Visibility: All
 ;CHECK-NEXT:     NumRanges: 2
 ;CHECK-NEXT:     - Range Type: 0
 ;CHECK-NEXT:       Register Space: 0

diff  --git a/llvm/test/ObjectYAML/DXContainer/RootSignature-InvalidType.yaml b/llvm/test/ObjectYAML/DXContainer/RootSignature-InvalidType.yaml
deleted file mode 100644
index 091e70789d956..0000000000000
--- a/llvm/test/ObjectYAML/DXContainer/RootSignature-InvalidType.yaml
+++ /dev/null
@@ -1,29 +0,0 @@
-# RUN: yaml2obj %s -o %t
-# RUN: not obj2yaml 2>&1 %t | FileCheck %s -DFILE=%t
-
-# CHECK: Error reading file: [[FILE]]: Invalid value for parameter type
-
-
---- !dxcontainer
-Header:
-  Hash:            [ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 
-                     0x0, 0x0, 0x0, 0x0, 0x0, 0x0 ]
-  Version:
-    Major:           1
-    Minor:           0
-  PartCount:       1
-  PartOffsets:     [ 60 ]
-Parts:
-  - Name:            RTS0
-    Size:            80
-    RootSignature:
-      Version: 2
-      NumRootParameters: 2
-      RootParametersOffset: 24
-      NumStaticSamplers: 0
-      StaticSamplersOffset: 64
-      Parameters:         
-      - ParameterType: 255 # INVALID
-        ShaderVisibility: 2 # Hull
-      AllowInputAssemblerInputLayout: true
-      DenyGeometryShaderRootAccess: true

diff  --git a/llvm/test/ObjectYAML/DXContainer/RootSignature-InvalidVisibility.yaml b/llvm/test/ObjectYAML/DXContainer/RootSignature-InvalidVisibility.yaml
deleted file mode 100644
index 1acaf6e4e08a4..0000000000000
--- a/llvm/test/ObjectYAML/DXContainer/RootSignature-InvalidVisibility.yaml
+++ /dev/null
@@ -1,33 +0,0 @@
-# RUN: yaml2obj %s -o %t
-# RUN: not obj2yaml 2>&1 %t | FileCheck %s -DFILE=%t
-
-# CHECK: Error reading file: [[FILE]]: Invalid value for shader visibility
-
-
---- !dxcontainer
-Header:
-  Hash:            [ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 
-                     0x0, 0x0, 0x0, 0x0, 0x0, 0x0 ]
-  Version:
-    Major:           1
-    Minor:           0
-  PartCount:       1
-  PartOffsets:     [ 60 ]
-Parts:
-  - Name:            RTS0
-    Size:            80
-    RootSignature:
-      Version: 2
-      NumRootParameters: 2
-      RootParametersOffset: 24
-      NumStaticSamplers: 0
-      StaticSamplersOffset: 64
-      Parameters:         
-      - ParameterType: 1 # Constants32Bit
-        ShaderVisibility: 255 # INVALID
-        Constants:
-          Num32BitValues: 21
-          ShaderRegister: 22
-          RegisterSpace: 23   
-      AllowInputAssemblerInputLayout: true
-      DenyGeometryShaderRootAccess: true


        


More information about the llvm-commits mailing list