[llvm] c21e2a5 - [DirectX] Moving Root Signature Metadata Parsing in to Shared Root Signature Metadata lib (#149221)

via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 23 17:03:16 PDT 2025


Author: joaosaffran
Date: 2025-07-23T17:03:13-07:00
New Revision: c21e2a5e24c7bc55767055e25bb2fb40cbec68c3

URL: https://github.com/llvm/llvm-project/commit/c21e2a5e24c7bc55767055e25bb2fb40cbec68c3
DIFF: https://github.com/llvm/llvm-project/commit/c21e2a5e24c7bc55767055e25bb2fb40cbec68c3.diff

LOG: [DirectX] Moving Root Signature Metadata Parsing in to Shared Root Signature Metadata lib (#149221)

This PR moves the existing Root Signature Metadata Parsing logic used
in `DXILRootSignature` to the common library used by both the frontend and
the backend. Closes:
[#145942](https://github.com/llvm/llvm-project/issues/145942)

---------

Co-authored-by: joaosaffran <joao.saffran at microsoft.com>

Added: 
    

Modified: 
    llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h
    llvm/include/llvm/MC/DXContainerRootSignature.h
    llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
    llvm/lib/Target/DirectX/DXILRootSignature.cpp
    llvm/lib/Target/DirectX/DXILRootSignature.h

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h b/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h
index 0aa122f668ef1..6fa51eded52f0 100644
--- a/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h
+++ b/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h
@@ -15,6 +15,8 @@
 #define LLVM_FRONTEND_HLSL_ROOTSIGNATUREMETADATA_H
 
 #include "llvm/Frontend/HLSL/HLSLRootSignature.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/MC/DXContainerRootSignature.h"
 
 namespace llvm {
 class LLVMContext;
@@ -49,6 +51,48 @@ class MetadataBuilder {
   SmallVector<Metadata *> GeneratedMetadata;
 };
 
+enum class RootSignatureElementKind {
+  Error = 0,
+  RootFlags = 1,
+  RootConstants = 2,
+  SRV = 3,
+  UAV = 4,
+  CBV = 5,
+  DescriptorTable = 6,
+  StaticSamplers = 7
+};
+
+class MetadataParser {
+public:
+  MetadataParser(MDNode *Root) : Root(Root) {}
+
+  LLVM_ABI bool ParseRootSignature(LLVMContext *Ctx,
+                                   mcdxbc::RootSignatureDesc &RSD);
+
+private:
+  bool parseRootFlags(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
+                      MDNode *RootFlagNode);
+  bool parseRootConstants(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
+                          MDNode *RootConstantNode);
+  bool parseRootDescriptors(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
+                            MDNode *RootDescriptorNode,
+                            RootSignatureElementKind ElementKind);
+  bool parseDescriptorRange(LLVMContext *Ctx, mcdxbc::DescriptorTable &Table,
+                            MDNode *RangeDescriptorNode);
+  bool parseDescriptorTable(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
+                            MDNode *DescriptorTableNode);
+  bool parseRootSignatureElement(LLVMContext *Ctx,
+                                 mcdxbc::RootSignatureDesc &RSD,
+                                 MDNode *Element);
+  bool parseStaticSampler(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
+                          MDNode *StaticSamplerNode);
+
+  bool validateRootSignature(LLVMContext *Ctx,
+                             const llvm::mcdxbc::RootSignatureDesc &RSD);
+
+  MDNode *Root;
+};
+
 } // namespace rootsig
 } // namespace hlsl
 } // namespace llvm

diff  --git a/llvm/include/llvm/MC/DXContainerRootSignature.h b/llvm/include/llvm/MC/DXContainerRootSignature.h
index 4b6b42f7d74f7..14a2429ffcc78 100644
--- a/llvm/include/llvm/MC/DXContainerRootSignature.h
+++ b/llvm/include/llvm/MC/DXContainerRootSignature.h
@@ -6,6 +6,9 @@
 //
 //===----------------------------------------------------------------------===//
 
+#ifndef LLVM_MC_DXCONTAINERROOTSIGNATURE_H
+#define LLVM_MC_DXCONTAINERROOTSIGNATURE_H
+
 #include "llvm/BinaryFormat/DXContainer.h"
 #include <cstdint>
 #include <limits>
@@ -116,3 +119,5 @@ struct RootSignatureDesc {
 };
 } // namespace mcdxbc
 } // namespace llvm
+
+#endif // LLVM_MC_DXCONTAINERROOTSIGNATURE_H

diff  --git a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
index f7669f09dcecc..53f59349ae029 100644
--- a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
+++ b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
@@ -12,6 +12,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Frontend/HLSL/RootSignatureMetadata.h"
+#include "llvm/Frontend/HLSL/RootSignatureValidations.h"
+#include "llvm/IR/DiagnosticInfo.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Metadata.h"
 #include "llvm/Support/ScopedPrinter.h"
@@ -20,6 +22,42 @@ namespace llvm {
 namespace hlsl {
 namespace rootsig {
 
+static std::optional<uint32_t> extractMdIntValue(MDNode *Node,
+                                                 unsigned int OpId) {
+  if (auto *CI =
+          mdconst::dyn_extract<ConstantInt>(Node->getOperand(OpId).get()))
+    return CI->getZExtValue();
+  return std::nullopt;
+}
+
+static std::optional<float> extractMdFloatValue(MDNode *Node,
+                                                unsigned int OpId) {
+  if (auto *CI = mdconst::dyn_extract<ConstantFP>(Node->getOperand(OpId).get()))
+    return CI->getValueAPF().convertToFloat();
+  return std::nullopt;
+}
+
+static std::optional<StringRef> extractMdStringValue(MDNode *Node,
+                                                     unsigned int OpId) {
+  MDString *NodeText = dyn_cast<MDString>(Node->getOperand(OpId));
+  if (NodeText == nullptr)
+    return std::nullopt;
+  return NodeText->getString();
+}
+
+static bool reportError(LLVMContext *Ctx, Twine Message,
+                        DiagnosticSeverity Severity = DS_Error) {
+  Ctx->diagnose(DiagnosticInfoGeneric(Message, Severity));
+  return true;
+}
+
+static bool reportValueError(LLVMContext *Ctx, Twine ParamName,
+                             uint32_t Value) {
+  Ctx->diagnose(DiagnosticInfoGeneric(
+      "Invalid value for " + ParamName + ": " + Twine(Value), DS_Error));
+  return true;
+}
+
 static const EnumEntry<dxil::ResourceClass> ResourceClassNames[] = {
     {"CBV", dxil::ResourceClass::CBuffer},
     {"SRV", dxil::ResourceClass::SRV},
@@ -189,6 +227,442 @@ MDNode *MetadataBuilder::BuildStaticSampler(const StaticSampler &Sampler) {
   return MDNode::get(Ctx, Operands);
 }
 
+bool MetadataParser::parseRootFlags(LLVMContext *Ctx,
+                                    mcdxbc::RootSignatureDesc &RSD,
+                                    MDNode *RootFlagNode) {
+
+  if (RootFlagNode->getNumOperands() != 2)
+    return reportError(Ctx, "Invalid format for RootFlag Element");
+
+  if (std::optional<uint32_t> Val = extractMdIntValue(RootFlagNode, 1))
+    RSD.Flags = *Val;
+  else
+    return reportError(Ctx, "Invalid value for RootFlag");
+
+  return false;
+}
+
+bool MetadataParser::parseRootConstants(LLVMContext *Ctx,
+                                        mcdxbc::RootSignatureDesc &RSD,
+                                        MDNode *RootConstantNode) {
+
+  if (RootConstantNode->getNumOperands() != 5)
+    return reportError(Ctx, "Invalid format for RootConstants Element");
+
+  dxbc::RTS0::v1::RootParameterHeader Header;
+  // The parameter offset doesn't matter here - we recalculate it during
+  // serialization  Header.ParameterOffset = 0;
+  Header.ParameterType =
+      llvm::to_underlying(dxbc::RootParameterType::Constants32Bit);
+
+  if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 1))
+    Header.ShaderVisibility = *Val;
+  else
+    return reportError(Ctx, "Invalid value for ShaderVisibility");
+
+  dxbc::RTS0::v1::RootConstants Constants;
+  if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 2))
+    Constants.ShaderRegister = *Val;
+  else
+    return reportError(Ctx, "Invalid value for ShaderRegister");
+
+  if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 3))
+    Constants.RegisterSpace = *Val;
+  else
+    return reportError(Ctx, "Invalid value for RegisterSpace");
+
+  if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 4))
+    Constants.Num32BitValues = *Val;
+  else
+    return reportError(Ctx, "Invalid value for Num32BitValues");
+
+  RSD.ParametersContainer.addParameter(Header, Constants);
+
+  return false;
+}
+
+bool MetadataParser::parseRootDescriptors(
+    LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
+    MDNode *RootDescriptorNode, RootSignatureElementKind ElementKind) {
+  assert(ElementKind == RootSignatureElementKind::SRV ||
+         ElementKind == RootSignatureElementKind::UAV ||
+         ElementKind == RootSignatureElementKind::CBV &&
+             "parseRootDescriptors should only be called with RootDescriptor "
+             "element kind.");
+  if (RootDescriptorNode->getNumOperands() != 5)
+    return reportError(Ctx, "Invalid format for Root Descriptor Element");
+
+  dxbc::RTS0::v1::RootParameterHeader Header;
+  switch (ElementKind) {
+  case RootSignatureElementKind::SRV:
+    Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::SRV);
+    break;
+  case RootSignatureElementKind::UAV:
+    Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::UAV);
+    break;
+  case RootSignatureElementKind::CBV:
+    Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::CBV);
+    break;
+  default:
+    llvm_unreachable("invalid Root Descriptor kind");
+    break;
+  }
+
+  if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 1))
+    Header.ShaderVisibility = *Val;
+  else
+    return reportError(Ctx, "Invalid value for ShaderVisibility");
+
+  dxbc::RTS0::v2::RootDescriptor Descriptor;
+  if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 2))
+    Descriptor.ShaderRegister = *Val;
+  else
+    return reportError(Ctx, "Invalid value for ShaderRegister");
+
+  if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 3))
+    Descriptor.RegisterSpace = *Val;
+  else
+    return reportError(Ctx, "Invalid value for RegisterSpace");
+
+  if (RSD.Version == 1) {
+    RSD.ParametersContainer.addParameter(Header, Descriptor);
+    return false;
+  }
+  assert(RSD.Version > 1);
+
+  if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 4))
+    Descriptor.Flags = *Val;
+  else
+    return reportError(Ctx, "Invalid value for Root Descriptor Flags");
+
+  RSD.ParametersContainer.addParameter(Header, Descriptor);
+  return false;
+}
+
+bool MetadataParser::parseDescriptorRange(LLVMContext *Ctx,
+                                          mcdxbc::DescriptorTable &Table,
+                                          MDNode *RangeDescriptorNode) {
+
+  if (RangeDescriptorNode->getNumOperands() != 6)
+    return reportError(Ctx, "Invalid format for Descriptor Range");
+
+  dxbc::RTS0::v2::DescriptorRange Range;
+
+  std::optional<StringRef> ElementText =
+      extractMdStringValue(RangeDescriptorNode, 0);
+
+  if (!ElementText.has_value())
+    return reportError(Ctx, "Descriptor Range, first element is not a string.");
+
+  Range.RangeType =
+      StringSwitch<uint32_t>(*ElementText)
+          .Case("CBV", llvm::to_underlying(dxbc::DescriptorRangeType::CBV))
+          .Case("SRV", llvm::to_underlying(dxbc::DescriptorRangeType::SRV))
+          .Case("UAV", llvm::to_underlying(dxbc::DescriptorRangeType::UAV))
+          .Case("Sampler",
+                llvm::to_underlying(dxbc::DescriptorRangeType::Sampler))
+          .Default(~0U);
+
+  if (Range.RangeType == ~0U)
+    return reportError(Ctx, "Invalid Descriptor Range type: " + *ElementText);
+
+  if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 1))
+    Range.NumDescriptors = *Val;
+  else
+    return reportError(Ctx, "Invalid value for Number of Descriptor in Range");
+
+  if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 2))
+    Range.BaseShaderRegister = *Val;
+  else
+    return reportError(Ctx, "Invalid value for BaseShaderRegister");
+
+  if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 3))
+    Range.RegisterSpace = *Val;
+  else
+    return reportError(Ctx, "Invalid value for RegisterSpace");
+
+  if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 4))
+    Range.OffsetInDescriptorsFromTableStart = *Val;
+  else
+    return reportError(Ctx,
+                       "Invalid value for OffsetInDescriptorsFromTableStart");
+
+  if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 5))
+    Range.Flags = *Val;
+  else
+    return reportError(Ctx, "Invalid value for Descriptor Range Flags");
+
+  Table.Ranges.push_back(Range);
+  return false;
+}
+
+bool MetadataParser::parseDescriptorTable(LLVMContext *Ctx,
+                                          mcdxbc::RootSignatureDesc &RSD,
+                                          MDNode *DescriptorTableNode) {
+  const unsigned int NumOperands = DescriptorTableNode->getNumOperands();
+  if (NumOperands < 2)
+    return reportError(Ctx, "Invalid format for Descriptor Table");
+
+  dxbc::RTS0::v1::RootParameterHeader Header;
+  if (std::optional<uint32_t> Val = extractMdIntValue(DescriptorTableNode, 1))
+    Header.ShaderVisibility = *Val;
+  else
+    return reportError(Ctx, "Invalid value for ShaderVisibility");
+
+  mcdxbc::DescriptorTable Table;
+  Header.ParameterType =
+      llvm::to_underlying(dxbc::RootParameterType::DescriptorTable);
+
+  for (unsigned int I = 2; I < NumOperands; I++) {
+    MDNode *Element = dyn_cast<MDNode>(DescriptorTableNode->getOperand(I));
+    if (Element == nullptr)
+      return reportError(Ctx, "Missing Root Element Metadata Node.");
+
+    if (parseDescriptorRange(Ctx, Table, Element))
+      return true;
+  }
+
+  RSD.ParametersContainer.addParameter(Header, Table);
+  return false;
+}
+
+bool MetadataParser::parseStaticSampler(LLVMContext *Ctx,
+                                        mcdxbc::RootSignatureDesc &RSD,
+                                        MDNode *StaticSamplerNode) {
+  if (StaticSamplerNode->getNumOperands() != 14)
+    return reportError(Ctx, "Invalid format for Static Sampler");
+
+  dxbc::RTS0::v1::StaticSampler Sampler;
+  if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 1))
+    Sampler.Filter = *Val;
+  else
+    return reportError(Ctx, "Invalid value for Filter");
+
+  if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 2))
+    Sampler.AddressU = *Val;
+  else
+    return reportError(Ctx, "Invalid value for AddressU");
+
+  if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 3))
+    Sampler.AddressV = *Val;
+  else
+    return reportError(Ctx, "Invalid value for AddressV");
+
+  if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 4))
+    Sampler.AddressW = *Val;
+  else
+    return reportError(Ctx, "Invalid value for AddressW");
+
+  if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 5))
+    Sampler.MipLODBias = *Val;
+  else
+    return reportError(Ctx, "Invalid value for MipLODBias");
+
+  if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 6))
+    Sampler.MaxAnisotropy = *Val;
+  else
+    return reportError(Ctx, "Invalid value for MaxAnisotropy");
+
+  if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 7))
+    Sampler.ComparisonFunc = *Val;
+  else
+    return reportError(Ctx, "Invalid value for ComparisonFunc ");
+
+  if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 8))
+    Sampler.BorderColor = *Val;
+  else
+    return reportError(Ctx, "Invalid value for ComparisonFunc ");
+
+  if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 9))
+    Sampler.MinLOD = *Val;
+  else
+    return reportError(Ctx, "Invalid value for MinLOD");
+
+  if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 10))
+    Sampler.MaxLOD = *Val;
+  else
+    return reportError(Ctx, "Invalid value for MaxLOD");
+
+  if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 11))
+    Sampler.ShaderRegister = *Val;
+  else
+    return reportError(Ctx, "Invalid value for ShaderRegister");
+
+  if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 12))
+    Sampler.RegisterSpace = *Val;
+  else
+    return reportError(Ctx, "Invalid value for RegisterSpace");
+
+  if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 13))
+    Sampler.ShaderVisibility = *Val;
+  else
+    return reportError(Ctx, "Invalid value for ShaderVisibility");
+
+  RSD.StaticSamplers.push_back(Sampler);
+  return false;
+}
+
+bool MetadataParser::parseRootSignatureElement(LLVMContext *Ctx,
+                                               mcdxbc::RootSignatureDesc &RSD,
+                                               MDNode *Element) {
+  std::optional<StringRef> ElementText = extractMdStringValue(Element, 0);
+  if (!ElementText.has_value())
+    return reportError(Ctx, "Invalid format for Root Element");
+
+  RootSignatureElementKind ElementKind =
+      StringSwitch<RootSignatureElementKind>(*ElementText)
+          .Case("RootFlags", RootSignatureElementKind::RootFlags)
+          .Case("RootConstants", RootSignatureElementKind::RootConstants)
+          .Case("RootCBV", RootSignatureElementKind::CBV)
+          .Case("RootSRV", RootSignatureElementKind::SRV)
+          .Case("RootUAV", RootSignatureElementKind::UAV)
+          .Case("DescriptorTable", RootSignatureElementKind::DescriptorTable)
+          .Case("StaticSampler", RootSignatureElementKind::StaticSamplers)
+          .Default(RootSignatureElementKind::Error);
+
+  switch (ElementKind) {
+
+  case RootSignatureElementKind::RootFlags:
+    return parseRootFlags(Ctx, RSD, Element);
+  case RootSignatureElementKind::RootConstants:
+    return parseRootConstants(Ctx, RSD, Element);
+  case RootSignatureElementKind::CBV:
+  case RootSignatureElementKind::SRV:
+  case RootSignatureElementKind::UAV:
+    return parseRootDescriptors(Ctx, RSD, Element, ElementKind);
+  case RootSignatureElementKind::DescriptorTable:
+    return parseDescriptorTable(Ctx, RSD, Element);
+  case RootSignatureElementKind::StaticSamplers:
+    return parseStaticSampler(Ctx, RSD, Element);
+  case RootSignatureElementKind::Error:
+    return reportError(Ctx, "Invalid Root Signature Element: " + *ElementText);
+  }
+
+  llvm_unreachable("Unhandled RootSignatureElementKind enum.");
+}
+
+bool MetadataParser::validateRootSignature(
+    LLVMContext *Ctx, const llvm::mcdxbc::RootSignatureDesc &RSD) {
+  if (!llvm::hlsl::rootsig::verifyVersion(RSD.Version)) {
+    return reportValueError(Ctx, "Version", RSD.Version);
+  }
+
+  if (!llvm::hlsl::rootsig::verifyRootFlag(RSD.Flags)) {
+    return reportValueError(Ctx, "RootFlags", RSD.Flags);
+  }
+
+  for (const mcdxbc::RootParameterInfo &Info : RSD.ParametersContainer) {
+    if (!dxbc::isValidShaderVisibility(Info.Header.ShaderVisibility))
+      return reportValueError(Ctx, "ShaderVisibility",
+                              Info.Header.ShaderVisibility);
+
+    assert(dxbc::isValidParameterType(Info.Header.ParameterType) &&
+           "Invalid value for ParameterType");
+
+    switch (Info.Header.ParameterType) {
+
+    case llvm::to_underlying(dxbc::RootParameterType::CBV):
+    case llvm::to_underlying(dxbc::RootParameterType::UAV):
+    case llvm::to_underlying(dxbc::RootParameterType::SRV): {
+      const dxbc::RTS0::v2::RootDescriptor &Descriptor =
+          RSD.ParametersContainer.getRootDescriptor(Info.Location);
+      if (!llvm::hlsl::rootsig::verifyRegisterValue(Descriptor.ShaderRegister))
+        return reportValueError(Ctx, "ShaderRegister",
+                                Descriptor.ShaderRegister);
+
+      if (!llvm::hlsl::rootsig::verifyRegisterSpace(Descriptor.RegisterSpace))
+        return reportValueError(Ctx, "RegisterSpace", Descriptor.RegisterSpace);
+
+      if (RSD.Version > 1) {
+        if (!llvm::hlsl::rootsig::verifyRootDescriptorFlag(RSD.Version,
+                                                           Descriptor.Flags))
+          return reportValueError(Ctx, "RootDescriptorFlag", Descriptor.Flags);
+      }
+      break;
+    }
+    case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
+      const mcdxbc::DescriptorTable &Table =
+          RSD.ParametersContainer.getDescriptorTable(Info.Location);
+      for (const dxbc::RTS0::v2::DescriptorRange &Range : Table) {
+        if (!llvm::hlsl::rootsig::verifyRangeType(Range.RangeType))
+          return reportValueError(Ctx, "RangeType", Range.RangeType);
+
+        if (!llvm::hlsl::rootsig::verifyRegisterSpace(Range.RegisterSpace))
+          return reportValueError(Ctx, "RegisterSpace", Range.RegisterSpace);
+
+        if (!llvm::hlsl::rootsig::verifyNumDescriptors(Range.NumDescriptors))
+          return reportValueError(Ctx, "NumDescriptors", Range.NumDescriptors);
+
+        if (!llvm::hlsl::rootsig::verifyDescriptorRangeFlag(
+                RSD.Version, Range.RangeType, Range.Flags))
+          return reportValueError(Ctx, "DescriptorFlag", Range.Flags);
+      }
+      break;
+    }
+    }
+  }
+
+  for (const dxbc::RTS0::v1::StaticSampler &Sampler : RSD.StaticSamplers) {
+    if (!llvm::hlsl::rootsig::verifySamplerFilter(Sampler.Filter))
+      return reportValueError(Ctx, "Filter", Sampler.Filter);
+
+    if (!llvm::hlsl::rootsig::verifyAddress(Sampler.AddressU))
+      return reportValueError(Ctx, "AddressU", Sampler.AddressU);
+
+    if (!llvm::hlsl::rootsig::verifyAddress(Sampler.AddressV))
+      return reportValueError(Ctx, "AddressV", Sampler.AddressV);
+
+    if (!llvm::hlsl::rootsig::verifyAddress(Sampler.AddressW))
+      return reportValueError(Ctx, "AddressW", Sampler.AddressW);
+
+    if (!llvm::hlsl::rootsig::verifyMipLODBias(Sampler.MipLODBias))
+      return reportValueError(Ctx, "MipLODBias", Sampler.MipLODBias);
+
+    if (!llvm::hlsl::rootsig::verifyMaxAnisotropy(Sampler.MaxAnisotropy))
+      return reportValueError(Ctx, "MaxAnisotropy", Sampler.MaxAnisotropy);
+
+    if (!llvm::hlsl::rootsig::verifyComparisonFunc(Sampler.ComparisonFunc))
+      return reportValueError(Ctx, "ComparisonFunc", Sampler.ComparisonFunc);
+
+    if (!llvm::hlsl::rootsig::verifyBorderColor(Sampler.BorderColor))
+      return reportValueError(Ctx, "BorderColor", Sampler.BorderColor);
+
+    if (!llvm::hlsl::rootsig::verifyLOD(Sampler.MinLOD))
+      return reportValueError(Ctx, "MinLOD", Sampler.MinLOD);
+
+    if (!llvm::hlsl::rootsig::verifyLOD(Sampler.MaxLOD))
+      return reportValueError(Ctx, "MaxLOD", Sampler.MaxLOD);
+
+    if (!llvm::hlsl::rootsig::verifyRegisterValue(Sampler.ShaderRegister))
+      return reportValueError(Ctx, "ShaderRegister", Sampler.ShaderRegister);
+
+    if (!llvm::hlsl::rootsig::verifyRegisterSpace(Sampler.RegisterSpace))
+      return reportValueError(Ctx, "RegisterSpace", Sampler.RegisterSpace);
+
+    if (!dxbc::isValidShaderVisibility(Sampler.ShaderVisibility))
+      return reportValueError(Ctx, "ShaderVisibility",
+                              Sampler.ShaderVisibility);
+  }
+
+  return false;
+}
+
+bool MetadataParser::ParseRootSignature(LLVMContext *Ctx,
+                                        mcdxbc::RootSignatureDesc &RSD) {
+  bool HasError = false;
+
+  // Loop through the Root Elements of the root signature.
+  for (const auto &Operand : Root->operands()) {
+    MDNode *Element = dyn_cast<MDNode>(Operand);
+    if (Element == nullptr)
+      return reportError(Ctx, "Missing Root Element Metadata Node.");
+
+    HasError = HasError || parseRootSignatureElement(Ctx, RSD, Element) ||
+               validateRootSignature(Ctx, RSD);
+  }
+
+  return HasError;
+}
 } // namespace rootsig
 } // namespace hlsl
 } // namespace llvm

diff  --git a/llvm/lib/Target/DirectX/DXILRootSignature.cpp b/llvm/lib/Target/DirectX/DXILRootSignature.cpp
index dfc81626da01f..ebdfcaa566b51 100644
--- a/llvm/lib/Target/DirectX/DXILRootSignature.cpp
+++ b/llvm/lib/Target/DirectX/DXILRootSignature.cpp
@@ -16,6 +16,7 @@
 #include "llvm/ADT/Twine.h"
 #include "llvm/Analysis/DXILMetadataAnalysis.h"
 #include "llvm/BinaryFormat/DXContainer.h"
+#include "llvm/Frontend/HLSL/RootSignatureMetadata.h"
 #include "llvm/Frontend/HLSL/RootSignatureValidations.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DiagnosticInfo.h"
@@ -29,25 +30,10 @@
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/raw_ostream.h"
 #include <cstdint>
-#include <optional>
-#include <utility>
 
 using namespace llvm;
 using namespace llvm::dxil;
 
-static bool reportError(LLVMContext *Ctx, Twine Message,
-                        DiagnosticSeverity Severity = DS_Error) {
-  Ctx->diagnose(DiagnosticInfoGeneric(Message, Severity));
-  return true;
-}
-
-static bool reportValueError(LLVMContext *Ctx, Twine ParamName,
-                             uint32_t Value) {
-  Ctx->diagnose(DiagnosticInfoGeneric(
-      "Invalid value for " + ParamName + ": " + Twine(Value), DS_Error));
-  return true;
-}
-
 static std::optional<uint32_t> extractMdIntValue(MDNode *Node,
                                                  unsigned int OpId) {
   if (auto *CI =
@@ -56,453 +42,10 @@ static std::optional<uint32_t> extractMdIntValue(MDNode *Node,
   return std::nullopt;
 }
 
-static std::optional<float> extractMdFloatValue(MDNode *Node,
-                                                unsigned int OpId) {
-  if (auto *CI = mdconst::dyn_extract<ConstantFP>(Node->getOperand(OpId).get()))
-    return CI->getValueAPF().convertToFloat();
-  return std::nullopt;
-}
-
-static std::optional<StringRef> extractMdStringValue(MDNode *Node,
-                                                     unsigned int OpId) {
-  MDString *NodeText = dyn_cast<MDString>(Node->getOperand(OpId));
-  if (NodeText == nullptr)
-    return std::nullopt;
-  return NodeText->getString();
-}
-
-static bool parseRootFlags(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
-                           MDNode *RootFlagNode) {
-
-  if (RootFlagNode->getNumOperands() != 2)
-    return reportError(Ctx, "Invalid format for RootFlag Element");
-
-  if (std::optional<uint32_t> Val = extractMdIntValue(RootFlagNode, 1))
-    RSD.Flags = *Val;
-  else
-    return reportError(Ctx, "Invalid value for RootFlag");
-
-  return false;
-}
-
-static bool parseRootConstants(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
-                               MDNode *RootConstantNode) {
-
-  if (RootConstantNode->getNumOperands() != 5)
-    return reportError(Ctx, "Invalid format for RootConstants Element");
-
-  dxbc::RTS0::v1::RootParameterHeader Header;
-  // The parameter offset doesn't matter here - we recalculate it during
-  // serialization  Header.ParameterOffset = 0;
-  Header.ParameterType =
-      llvm::to_underlying(dxbc::RootParameterType::Constants32Bit);
-
-  if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 1))
-    Header.ShaderVisibility = *Val;
-  else
-    return reportError(Ctx, "Invalid value for ShaderVisibility");
-
-  dxbc::RTS0::v1::RootConstants Constants;
-  if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 2))
-    Constants.ShaderRegister = *Val;
-  else
-    return reportError(Ctx, "Invalid value for ShaderRegister");
-
-  if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 3))
-    Constants.RegisterSpace = *Val;
-  else
-    return reportError(Ctx, "Invalid value for RegisterSpace");
-
-  if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 4))
-    Constants.Num32BitValues = *Val;
-  else
-    return reportError(Ctx, "Invalid value for Num32BitValues");
-
-  RSD.ParametersContainer.addParameter(Header, Constants);
-
-  return false;
-}
-
-static bool parseRootDescriptors(LLVMContext *Ctx,
-                                 mcdxbc::RootSignatureDesc &RSD,
-                                 MDNode *RootDescriptorNode,
-                                 RootSignatureElementKind ElementKind) {
-  assert(ElementKind == RootSignatureElementKind::SRV ||
-         ElementKind == RootSignatureElementKind::UAV ||
-         ElementKind == RootSignatureElementKind::CBV &&
-             "parseRootDescriptors should only be called with RootDescriptor "
-             "element kind.");
-  if (RootDescriptorNode->getNumOperands() != 5)
-    return reportError(Ctx, "Invalid format for Root Descriptor Element");
-
-  dxbc::RTS0::v1::RootParameterHeader Header;
-  switch (ElementKind) {
-  case RootSignatureElementKind::SRV:
-    Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::SRV);
-    break;
-  case RootSignatureElementKind::UAV:
-    Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::UAV);
-    break;
-  case RootSignatureElementKind::CBV:
-    Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::CBV);
-    break;
-  default:
-    llvm_unreachable("invalid Root Descriptor kind");
-    break;
-  }
-
-  if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 1))
-    Header.ShaderVisibility = *Val;
-  else
-    return reportError(Ctx, "Invalid value for ShaderVisibility");
-
-  dxbc::RTS0::v2::RootDescriptor Descriptor;
-  if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 2))
-    Descriptor.ShaderRegister = *Val;
-  else
-    return reportError(Ctx, "Invalid value for ShaderRegister");
-
-  if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 3))
-    Descriptor.RegisterSpace = *Val;
-  else
-    return reportError(Ctx, "Invalid value for RegisterSpace");
-
-  if (RSD.Version == 1) {
-    RSD.ParametersContainer.addParameter(Header, Descriptor);
-    return false;
-  }
-  assert(RSD.Version > 1);
-
-  if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 4))
-    Descriptor.Flags = *Val;
-  else
-    return reportError(Ctx, "Invalid value for Root Descriptor Flags");
-
-  RSD.ParametersContainer.addParameter(Header, Descriptor);
-  return false;
-}
-
-static bool parseDescriptorRange(LLVMContext *Ctx,
-                                 mcdxbc::DescriptorTable &Table,
-                                 MDNode *RangeDescriptorNode) {
-
-  if (RangeDescriptorNode->getNumOperands() != 6)
-    return reportError(Ctx, "Invalid format for Descriptor Range");
-
-  dxbc::RTS0::v2::DescriptorRange Range;
-
-  std::optional<StringRef> ElementText =
-      extractMdStringValue(RangeDescriptorNode, 0);
-
-  if (!ElementText.has_value())
-    return reportError(Ctx, "Descriptor Range, first element is not a string.");
-
-  Range.RangeType =
-      StringSwitch<uint32_t>(*ElementText)
-          .Case("CBV", llvm::to_underlying(dxbc::DescriptorRangeType::CBV))
-          .Case("SRV", llvm::to_underlying(dxbc::DescriptorRangeType::SRV))
-          .Case("UAV", llvm::to_underlying(dxbc::DescriptorRangeType::UAV))
-          .Case("Sampler",
-                llvm::to_underlying(dxbc::DescriptorRangeType::Sampler))
-          .Default(~0U);
-
-  if (Range.RangeType == ~0U)
-    return reportError(Ctx, "Invalid Descriptor Range type: " + *ElementText);
-
-  if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 1))
-    Range.NumDescriptors = *Val;
-  else
-    return reportError(Ctx, "Invalid value for Number of Descriptor in Range");
-
-  if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 2))
-    Range.BaseShaderRegister = *Val;
-  else
-    return reportError(Ctx, "Invalid value for BaseShaderRegister");
-
-  if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 3))
-    Range.RegisterSpace = *Val;
-  else
-    return reportError(Ctx, "Invalid value for RegisterSpace");
-
-  if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 4))
-    Range.OffsetInDescriptorsFromTableStart = *Val;
-  else
-    return reportError(Ctx,
-                       "Invalid value for OffsetInDescriptorsFromTableStart");
-
-  if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 5))
-    Range.Flags = *Val;
-  else
-    return reportError(Ctx, "Invalid value for Descriptor Range Flags");
-
-  Table.Ranges.push_back(Range);
-  return false;
-}
-
-static bool parseDescriptorTable(LLVMContext *Ctx,
-                                 mcdxbc::RootSignatureDesc &RSD,
-                                 MDNode *DescriptorTableNode) {
-  const unsigned int NumOperands = DescriptorTableNode->getNumOperands();
-  if (NumOperands < 2)
-    return reportError(Ctx, "Invalid format for Descriptor Table");
-
-  dxbc::RTS0::v1::RootParameterHeader Header;
-  if (std::optional<uint32_t> Val = extractMdIntValue(DescriptorTableNode, 1))
-    Header.ShaderVisibility = *Val;
-  else
-    return reportError(Ctx, "Invalid value for ShaderVisibility");
-
-  mcdxbc::DescriptorTable Table;
-  Header.ParameterType =
-      llvm::to_underlying(dxbc::RootParameterType::DescriptorTable);
-
-  for (unsigned int I = 2; I < NumOperands; I++) {
-    MDNode *Element = dyn_cast<MDNode>(DescriptorTableNode->getOperand(I));
-    if (Element == nullptr)
-      return reportError(Ctx, "Missing Root Element Metadata Node.");
-
-    if (parseDescriptorRange(Ctx, Table, Element))
-      return true;
-  }
-
-  RSD.ParametersContainer.addParameter(Header, Table);
-  return false;
-}
-
-static bool parseStaticSampler(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
-                               MDNode *StaticSamplerNode) {
-  if (StaticSamplerNode->getNumOperands() != 14)
-    return reportError(Ctx, "Invalid format for Static Sampler");
-
-  dxbc::RTS0::v1::StaticSampler Sampler;
-  if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 1))
-    Sampler.Filter = *Val;
-  else
-    return reportError(Ctx, "Invalid value for Filter");
-
-  if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 2))
-    Sampler.AddressU = *Val;
-  else
-    return reportError(Ctx, "Invalid value for AddressU");
-
-  if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 3))
-    Sampler.AddressV = *Val;
-  else
-    return reportError(Ctx, "Invalid value for AddressV");
-
-  if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 4))
-    Sampler.AddressW = *Val;
-  else
-    return reportError(Ctx, "Invalid value for AddressW");
-
-  if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 5))
-    Sampler.MipLODBias = *Val;
-  else
-    return reportError(Ctx, "Invalid value for MipLODBias");
-
-  if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 6))
-    Sampler.MaxAnisotropy = *Val;
-  else
-    return reportError(Ctx, "Invalid value for MaxAnisotropy");
-
-  if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 7))
-    Sampler.ComparisonFunc = *Val;
-  else
-    return reportError(Ctx, "Invalid value for ComparisonFunc ");
-
-  if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 8))
-    Sampler.BorderColor = *Val;
-  else
-    return reportError(Ctx, "Invalid value for ComparisonFunc ");
-
-  if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 9))
-    Sampler.MinLOD = *Val;
-  else
-    return reportError(Ctx, "Invalid value for MinLOD");
-
-  if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 10))
-    Sampler.MaxLOD = *Val;
-  else
-    return reportError(Ctx, "Invalid value for MaxLOD");
-
-  if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 11))
-    Sampler.ShaderRegister = *Val;
-  else
-    return reportError(Ctx, "Invalid value for ShaderRegister");
-
-  if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 12))
-    Sampler.RegisterSpace = *Val;
-  else
-    return reportError(Ctx, "Invalid value for RegisterSpace");
-
-  if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 13))
-    Sampler.ShaderVisibility = *Val;
-  else
-    return reportError(Ctx, "Invalid value for ShaderVisibility");
-
-  RSD.StaticSamplers.push_back(Sampler);
-  return false;
-}
-
-static bool parseRootSignatureElement(LLVMContext *Ctx,
-                                      mcdxbc::RootSignatureDesc &RSD,
-                                      MDNode *Element) {
-  std::optional<StringRef> ElementText = extractMdStringValue(Element, 0);
-  if (!ElementText.has_value())
-    return reportError(Ctx, "Invalid format for Root Element");
-
-  RootSignatureElementKind ElementKind =
-      StringSwitch<RootSignatureElementKind>(*ElementText)
-          .Case("RootFlags", RootSignatureElementKind::RootFlags)
-          .Case("RootConstants", RootSignatureElementKind::RootConstants)
-          .Case("RootCBV", RootSignatureElementKind::CBV)
-          .Case("RootSRV", RootSignatureElementKind::SRV)
-          .Case("RootUAV", RootSignatureElementKind::UAV)
-          .Case("DescriptorTable", RootSignatureElementKind::DescriptorTable)
-          .Case("StaticSampler", RootSignatureElementKind::StaticSamplers)
-          .Default(RootSignatureElementKind::Error);
-
-  switch (ElementKind) {
-
-  case RootSignatureElementKind::RootFlags:
-    return parseRootFlags(Ctx, RSD, Element);
-  case RootSignatureElementKind::RootConstants:
-    return parseRootConstants(Ctx, RSD, Element);
-  case RootSignatureElementKind::CBV:
-  case RootSignatureElementKind::SRV:
-  case RootSignatureElementKind::UAV:
-    return parseRootDescriptors(Ctx, RSD, Element, ElementKind);
-  case RootSignatureElementKind::DescriptorTable:
-    return parseDescriptorTable(Ctx, RSD, Element);
-  case RootSignatureElementKind::StaticSamplers:
-    return parseStaticSampler(Ctx, RSD, Element);
-  case RootSignatureElementKind::Error:
-    return reportError(Ctx, "Invalid Root Signature Element: " + *ElementText);
-  }
-
-  llvm_unreachable("Unhandled RootSignatureElementKind enum.");
-}
-
-static bool parse(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
-                  MDNode *Node) {
-  bool HasError = false;
-
-  // Loop through the Root Elements of the root signature.
-  for (const auto &Operand : Node->operands()) {
-    MDNode *Element = dyn_cast<MDNode>(Operand);
-    if (Element == nullptr)
-      return reportError(Ctx, "Missing Root Element Metadata Node.");
-
-    HasError = HasError || parseRootSignatureElement(Ctx, RSD, Element);
-  }
-
-  return HasError;
-}
-
-static bool validate(LLVMContext *Ctx, const mcdxbc::RootSignatureDesc &RSD) {
-
-  if (!llvm::hlsl::rootsig::verifyVersion(RSD.Version)) {
-    return reportValueError(Ctx, "Version", RSD.Version);
-  }
-
-  if (!llvm::hlsl::rootsig::verifyRootFlag(RSD.Flags)) {
-    return reportValueError(Ctx, "RootFlags", RSD.Flags);
-  }
-
-  for (const mcdxbc::RootParameterInfo &Info : RSD.ParametersContainer) {
-    if (!dxbc::isValidShaderVisibility(Info.Header.ShaderVisibility))
-      return reportValueError(Ctx, "ShaderVisibility",
-                              Info.Header.ShaderVisibility);
-
-    assert(dxbc::isValidParameterType(Info.Header.ParameterType) &&
-           "Invalid value for ParameterType");
-
-    switch (Info.Header.ParameterType) {
-
-    case llvm::to_underlying(dxbc::RootParameterType::CBV):
-    case llvm::to_underlying(dxbc::RootParameterType::UAV):
-    case llvm::to_underlying(dxbc::RootParameterType::SRV): {
-      const dxbc::RTS0::v2::RootDescriptor &Descriptor =
-          RSD.ParametersContainer.getRootDescriptor(Info.Location);
-      if (!llvm::hlsl::rootsig::verifyRegisterValue(Descriptor.ShaderRegister))
-        return reportValueError(Ctx, "ShaderRegister",
-                                Descriptor.ShaderRegister);
-
-      if (!llvm::hlsl::rootsig::verifyRegisterSpace(Descriptor.RegisterSpace))
-        return reportValueError(Ctx, "RegisterSpace", Descriptor.RegisterSpace);
-
-      if (RSD.Version > 1) {
-        if (!llvm::hlsl::rootsig::verifyRootDescriptorFlag(RSD.Version,
-                                                           Descriptor.Flags))
-          return reportValueError(Ctx, "RootDescriptorFlag", Descriptor.Flags);
-      }
-      break;
-    }
-    case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
-      const mcdxbc::DescriptorTable &Table =
-          RSD.ParametersContainer.getDescriptorTable(Info.Location);
-      for (const dxbc::RTS0::v2::DescriptorRange &Range : Table) {
-        if (!llvm::hlsl::rootsig::verifyRangeType(Range.RangeType))
-          return reportValueError(Ctx, "RangeType", Range.RangeType);
-
-        if (!llvm::hlsl::rootsig::verifyRegisterSpace(Range.RegisterSpace))
-          return reportValueError(Ctx, "RegisterSpace", Range.RegisterSpace);
-
-        if (!llvm::hlsl::rootsig::verifyNumDescriptors(Range.NumDescriptors))
-          return reportValueError(Ctx, "NumDescriptors", Range.NumDescriptors);
-
-        if (!llvm::hlsl::rootsig::verifyDescriptorRangeFlag(
-                RSD.Version, Range.RangeType, Range.Flags))
-          return reportValueError(Ctx, "DescriptorFlag", Range.Flags);
-      }
-      break;
-    }
-    }
-  }
-
-  for (const dxbc::RTS0::v1::StaticSampler &Sampler : RSD.StaticSamplers) {
-    if (!llvm::hlsl::rootsig::verifySamplerFilter(Sampler.Filter))
-      return reportValueError(Ctx, "Filter", Sampler.Filter);
-
-    if (!llvm::hlsl::rootsig::verifyAddress(Sampler.AddressU))
-      return reportValueError(Ctx, "AddressU", Sampler.AddressU);
-
-    if (!llvm::hlsl::rootsig::verifyAddress(Sampler.AddressV))
-      return reportValueError(Ctx, "AddressV", Sampler.AddressV);
-
-    if (!llvm::hlsl::rootsig::verifyAddress(Sampler.AddressW))
-      return reportValueError(Ctx, "AddressW", Sampler.AddressW);
-
-    if (!llvm::hlsl::rootsig::verifyMipLODBias(Sampler.MipLODBias))
-      return reportValueError(Ctx, "MipLODBias", Sampler.MipLODBias);
-
-    if (!llvm::hlsl::rootsig::verifyMaxAnisotropy(Sampler.MaxAnisotropy))
-      return reportValueError(Ctx, "MaxAnisotropy", Sampler.MaxAnisotropy);
-
-    if (!llvm::hlsl::rootsig::verifyComparisonFunc(Sampler.ComparisonFunc))
-      return reportValueError(Ctx, "ComparisonFunc", Sampler.ComparisonFunc);
-
-    if (!llvm::hlsl::rootsig::verifyBorderColor(Sampler.BorderColor))
-      return reportValueError(Ctx, "BorderColor", Sampler.BorderColor);
-
-    if (!llvm::hlsl::rootsig::verifyLOD(Sampler.MinLOD))
-      return reportValueError(Ctx, "MinLOD", Sampler.MinLOD);
-
-    if (!llvm::hlsl::rootsig::verifyLOD(Sampler.MaxLOD))
-      return reportValueError(Ctx, "MaxLOD", Sampler.MaxLOD);
-
-    if (!llvm::hlsl::rootsig::verifyRegisterValue(Sampler.ShaderRegister))
-      return reportValueError(Ctx, "ShaderRegister", Sampler.ShaderRegister);
-
-    if (!llvm::hlsl::rootsig::verifyRegisterSpace(Sampler.RegisterSpace))
-      return reportValueError(Ctx, "RegisterSpace", Sampler.RegisterSpace);
-
-    if (!dxbc::isValidShaderVisibility(Sampler.ShaderVisibility))
-      return reportValueError(Ctx, "ShaderVisibility",
-                              Sampler.ShaderVisibility);
-  }
-
-  return false;
+static bool reportError(LLVMContext *Ctx, Twine Message,
+                        DiagnosticSeverity Severity = DS_Error) {
+  Ctx->diagnose(DiagnosticInfoGeneric(Message, Severity));
+  return true;
 }
 
 static SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc>
@@ -584,7 +127,9 @@ analyzeModule(Module &M) {
     // static sampler offset is calculated when writting dxcontainer.
     RSD.StaticSamplersOffset = 0u;
 
-    if (parse(Ctx, RSD, RootElementListNode) || validate(Ctx, RSD)) {
+    hlsl::rootsig::MetadataParser MDParser(RootElementListNode);
+
+    if (MDParser.ParseRootSignature(Ctx, RSD)) {
       return RSDMap;
     }
 

diff  --git a/llvm/lib/Target/DirectX/DXILRootSignature.h b/llvm/lib/Target/DirectX/DXILRootSignature.h
index fc39b38258df8..254b7ff504633 100644
--- a/llvm/lib/Target/DirectX/DXILRootSignature.h
+++ b/llvm/lib/Target/DirectX/DXILRootSignature.h
@@ -26,17 +26,6 @@
 namespace llvm {
 namespace dxil {
 
-enum class RootSignatureElementKind {
-  Error = 0,
-  RootFlags = 1,
-  RootConstants = 2,
-  SRV = 3,
-  UAV = 4,
-  CBV = 5,
-  DescriptorTable = 6,
-  StaticSamplers = 7
-};
-
 class RootSignatureBindingInfo {
 private:
   SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc> FuncToRsMap;


        


More information about the llvm-commits mailing list