[llvm] [DirectX] Error handling improve in root signature metadata Parser (PR #149232)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Jul 23 18:01:23 PDT 2025
https://github.com/joaosaffran updated https://github.com/llvm/llvm-project/pull/149232
>From 72de785e09281cae8f5eb2a0fa09770a87f3273f Mon Sep 17 00:00:00 2001
From: joaosaffran <joao.saffran at microsoft.com>
Date: Tue, 15 Jul 2025 21:27:04 +0000
Subject: [PATCH 01/10] refactoring init
---
.../Frontend/HLSL/RootSignatureMetadata.h | 35 ++
.../Frontend/HLSL/RootSignatureMetadata.cpp | 336 ++++++++++++++++++
llvm/lib/Target/DirectX/DXILRootSignature.h | 11 +-
3 files changed, 372 insertions(+), 10 deletions(-)
diff --git a/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h b/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h
index 0aa122f668ef1..5aa6c30491025 100644
--- a/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h
+++ b/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h
@@ -15,6 +15,10 @@
#define LLVM_FRONTEND_HLSL_ROOTSIGNATUREMETADATA_H
#include "llvm/Frontend/HLSL/HLSLRootSignature.h"
+#include "llvm/MC/DXContainerRootSignature.h"
+#include "llvm/IR/Function.h"
+#include "llvm/Support/Error.h"
+#include <unordered_map>
namespace llvm {
class LLVMContext;
@@ -49,6 +53,37 @@ class MetadataBuilder {
SmallVector<Metadata *> GeneratedMetadata;
};
+enum class RootSignatureElementKind {
+ Error = 0,
+ RootFlags = 1,
+ RootConstants = 2,
+ SRV = 3,
+ UAV = 4,
+ CBV = 5,
+ DescriptorTable = 6,
+ StaticSamplers = 7
+};
+
+class MetadataParser {
+public:
+ using MapT = SmallDenseMap<const Function *, llvm::mcdxbc::RootSignatureDesc>;
+ MetadataParser(llvm::LLVMContext &Ctx, MDNode* Root): Ctx(Ctx), Root(Root) {}
+
+ /// Iterates through root signature and converts them into MapT
+ LLVM_ABI llvm::Expected<MapT*> ParseRootSignature();
+
+private:
+ llvm::Error parseRootFlags(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD, MDNode *RootFlagNode);
+ llvm::Error parseRootConstants(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD, MDNode *RootConstantNode);
+ llvm::Error parseRootDescriptors(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD, MDNode *RootDescriptorNode, RootSignatureElementKind ElementKind);
+ llvm::Error parseDescriptorRange(LLVMContext *Ctx, mcdxbc::DescriptorTable &Table, MDNode *RangeDescriptorNode);
+ llvm::Error parseDescriptorTable(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD, MDNode *DescriptorTableNode);
+ llvm::Error parseRootSignatureElement(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD, MDNode *Element);
+ llvm::Error parseStaticSampler(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD, MDNode *StaticSamplerNode);
+ llvm::LLVMContext &Ctx;
+ MDNode* Root;
+};
+
} // namespace rootsig
} // namespace hlsl
} // namespace llvm
diff --git a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
index f7669f09dcecc..2ae305dfb0f86 100644
--- a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
+++ b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
@@ -12,8 +12,10 @@
//===----------------------------------------------------------------------===//
#include "llvm/Frontend/HLSL/RootSignatureMetadata.h"
+#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Metadata.h"
+#include "llvm/Support/Error.h"
#include "llvm/Support/ScopedPrinter.h"
namespace llvm {
@@ -34,6 +36,42 @@ static std::optional<StringRef> getResourceName(dxil::ResourceClass Class) {
return std::nullopt;
}
+static bool reportError(LLVMContext *Ctx, Twine Message,
+ DiagnosticSeverity Severity = DS_Error) {
+ Ctx->diagnose(DiagnosticInfoGeneric(Message, Severity));
+ return true;
+}
+
+static bool reportValueError(LLVMContext *Ctx, Twine ParamName,
+ uint32_t Value) {
+ Ctx->diagnose(DiagnosticInfoGeneric(
+ "Invalid value for " + ParamName + ": " + Twine(Value), DS_Error));
+ return true;
+}
+
+static std::optional<uint32_t> extractMdIntValue(MDNode *Node,
+ unsigned int OpId) {
+ if (auto *CI =
+ mdconst::dyn_extract<ConstantInt>(Node->getOperand(OpId).get()))
+ return CI->getZExtValue();
+ return std::nullopt;
+}
+
+static std::optional<float> extractMdFloatValue(MDNode *Node,
+ unsigned int OpId) {
+ if (auto *CI = mdconst::dyn_extract<ConstantFP>(Node->getOperand(OpId).get()))
+ return CI->getValueAPF().convertToFloat();
+ return std::nullopt;
+}
+
+static std::optional<StringRef> extractMdStringValue(MDNode *Node,
+ unsigned int OpId) {
+ MDString *NodeText = dyn_cast<MDString>(Node->getOperand(OpId));
+ if (NodeText == nullptr)
+ return std::nullopt;
+ return NodeText->getString();
+}
+
namespace {
// We use the OverloadVisit with std::visit to ensure the compiler catches if a
@@ -189,6 +227,304 @@ MDNode *MetadataBuilder::BuildStaticSampler(const StaticSampler &Sampler) {
return MDNode::get(Ctx, Operands);
}
+llvm::Error MetadataParser::parseRootFlags(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD, MDNode *RootFlagNode){
+ if (RootFlagNode->getNumOperands() != 2)
+ return reportError(Ctx, "Invalid format for RootFlag Element");
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(RootFlagNode, 1))
+ RSD.Flags = *Val;
+ else
+ return reportError(Ctx, "Invalid value for RootFlag");
+
+ return false;
+}
+
+llvm::Error MetadataParser::parseRootConstants(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD, MDNode *RootConstantNode){
+if (RootConstantNode->getNumOperands() != 5)
+ return reportError(Ctx, "Invalid format for RootConstants Element");
+
+ dxbc::RTS0::v1::RootParameterHeader Header;
+ // The parameter offset doesn't matter here - we recalculate it during
+ // serialization Header.ParameterOffset = 0;
+ Header.ParameterType =
+ llvm::to_underlying(dxbc::RootParameterType::Constants32Bit);
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 1))
+ Header.ShaderVisibility = *Val;
+ else
+ return reportError(Ctx, "Invalid value for ShaderVisibility");
+
+ dxbc::RTS0::v1::RootConstants Constants;
+ if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 2))
+ Constants.ShaderRegister = *Val;
+ else
+ return reportError(Ctx, "Invalid value for ShaderRegister");
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 3))
+ Constants.RegisterSpace = *Val;
+ else
+ return reportError(Ctx, "Invalid value for RegisterSpace");
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 4))
+ Constants.Num32BitValues = *Val;
+ else
+ return reportError(Ctx, "Invalid value for Num32BitValues");
+
+ RSD.ParametersContainer.addParameter(Header, Constants);
+
+ return false;
+}
+
+llvm::Error MetadataParser::parseRootDescriptors(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD, MDNode *RootDescriptorNode, RootSignatureElementKind ElementKind){
+assert(ElementKind == RootSignatureElementKind::SRV ||
+ ElementKind == RootSignatureElementKind::UAV ||
+ ElementKind == RootSignatureElementKind::CBV &&
+ "parseRootDescriptors should only be called with RootDescriptor "
+ "element kind.");
+ if (RootDescriptorNode->getNumOperands() != 5)
+ return reportError(Ctx, "Invalid format for Root Descriptor Element");
+
+ dxbc::RTS0::v1::RootParameterHeader Header;
+ switch (ElementKind) {
+ case RootSignatureElementKind::SRV:
+ Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::SRV);
+ break;
+ case RootSignatureElementKind::UAV:
+ Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::UAV);
+ break;
+ case RootSignatureElementKind::CBV:
+ Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::CBV);
+ break;
+ default:
+ llvm_unreachable("invalid Root Descriptor kind");
+ break;
+ }
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 1))
+ Header.ShaderVisibility = *Val;
+ else
+ return reportError(Ctx, "Invalid value for ShaderVisibility");
+
+ dxbc::RTS0::v2::RootDescriptor Descriptor;
+ if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 2))
+ Descriptor.ShaderRegister = *Val;
+ else
+ return reportError(Ctx, "Invalid value for ShaderRegister");
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 3))
+ Descriptor.RegisterSpace = *Val;
+ else
+ return reportError(Ctx, "Invalid value for RegisterSpace");
+
+ if (RSD.Version == 1) {
+ RSD.ParametersContainer.addParameter(Header, Descriptor);
+ return false;
+ }
+ assert(RSD.Version > 1);
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 4))
+ Descriptor.Flags = *Val;
+ else
+ return reportError(Ctx, "Invalid value for Root Descriptor Flags");
+
+ RSD.ParametersContainer.addParameter(Header, Descriptor);
+ return false;
+}
+
+llvm::Error MetadataParser::parseDescriptorRange(LLVMContext *Ctx, mcdxbc::DescriptorTable &Table, MDNode *RangeDescriptorNode){
+if (RangeDescriptorNode->getNumOperands() != 6)
+ return reportError(Ctx, "Invalid format for Descriptor Range");
+
+ dxbc::RTS0::v2::DescriptorRange Range;
+
+ std::optional<StringRef> ElementText =
+ extractMdStringValue(RangeDescriptorNode, 0);
+
+ if (!ElementText.has_value())
+ return reportError(Ctx, "Descriptor Range, first element is not a string.");
+
+ Range.RangeType =
+ StringSwitch<uint32_t>(*ElementText)
+ .Case("CBV", llvm::to_underlying(dxbc::DescriptorRangeType::CBV))
+ .Case("SRV", llvm::to_underlying(dxbc::DescriptorRangeType::SRV))
+ .Case("UAV", llvm::to_underlying(dxbc::DescriptorRangeType::UAV))
+ .Case("Sampler",
+ llvm::to_underlying(dxbc::DescriptorRangeType::Sampler))
+ .Default(~0U);
+
+ if (Range.RangeType == ~0U)
+ return reportError(Ctx, "Invalid Descriptor Range type: " + *ElementText);
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 1))
+ Range.NumDescriptors = *Val;
+ else
+ return reportError(Ctx, "Invalid value for Number of Descriptor in Range");
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 2))
+ Range.BaseShaderRegister = *Val;
+ else
+ return reportError(Ctx, "Invalid value for BaseShaderRegister");
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 3))
+ Range.RegisterSpace = *Val;
+ else
+ return reportError(Ctx, "Invalid value for RegisterSpace");
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 4))
+ Range.OffsetInDescriptorsFromTableStart = *Val;
+ else
+ return reportError(Ctx,
+ "Invalid value for OffsetInDescriptorsFromTableStart");
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 5))
+ Range.Flags = *Val;
+ else
+ return reportError(Ctx, "Invalid value for Descriptor Range Flags");
+
+ Table.Ranges.push_back(Range);
+ return false;
+}
+
+llvm::Error MetadataParser::parseDescriptorTable(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD, MDNode *DescriptorTableNode){
+const unsigned int NumOperands = DescriptorTableNode->getNumOperands();
+ if (NumOperands < 2)
+ return reportError(Ctx, "Invalid format for Descriptor Table");
+
+ dxbc::RTS0::v1::RootParameterHeader Header;
+ if (std::optional<uint32_t> Val = extractMdIntValue(DescriptorTableNode, 1))
+ Header.ShaderVisibility = *Val;
+ else
+ return reportError(Ctx, "Invalid value for ShaderVisibility");
+
+ mcdxbc::DescriptorTable Table;
+ Header.ParameterType =
+ llvm::to_underlying(dxbc::RootParameterType::DescriptorTable);
+
+ for (unsigned int I = 2; I < NumOperands; I++) {
+ MDNode *Element = dyn_cast<MDNode>(DescriptorTableNode->getOperand(I));
+ if (Element == nullptr)
+ return reportError(Ctx, "Missing Root Element Metadata Node.");
+
+ if (parseDescriptorRange(Ctx, Table, Element))
+ return true;
+ }
+
+ RSD.ParametersContainer.addParameter(Header, Table);
+ return false;
+}
+
+llvm::Error parseStaticSampler(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD, MDNode *StaticSamplerNode) {
+ if (StaticSamplerNode->getNumOperands() != 14)
+ return reportError(Ctx, "Invalid format for Static Sampler");
+
+ dxbc::RTS0::v1::StaticSampler Sampler;
+ if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 1))
+ Sampler.Filter = *Val;
+ else
+ return reportError(Ctx, "Invalid value for Filter");
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 2))
+ Sampler.AddressU = *Val;
+ else
+ return reportError(Ctx, "Invalid value for AddressU");
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 3))
+ Sampler.AddressV = *Val;
+ else
+ return reportError(Ctx, "Invalid value for AddressV");
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 4))
+ Sampler.AddressW = *Val;
+ else
+ return reportError(Ctx, "Invalid value for AddressW");
+
+ if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 5))
+ Sampler.MipLODBias = *Val;
+ else
+ return reportError(Ctx, "Invalid value for MipLODBias");
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 6))
+ Sampler.MaxAnisotropy = *Val;
+ else
+ return reportError(Ctx, "Invalid value for MaxAnisotropy");
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 7))
+ Sampler.ComparisonFunc = *Val;
+ else
+ return reportError(Ctx, "Invalid value for ComparisonFunc ");
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 8))
+ Sampler.BorderColor = *Val;
+ else
+ return reportError(Ctx, "Invalid value for ComparisonFunc ");
+
+ if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 9))
+ Sampler.MinLOD = *Val;
+ else
+ return reportError(Ctx, "Invalid value for MinLOD");
+
+ if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 10))
+ Sampler.MaxLOD = *Val;
+ else
+ return reportError(Ctx, "Invalid value for MaxLOD");
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 11))
+ Sampler.ShaderRegister = *Val;
+ else
+ return reportError(Ctx, "Invalid value for ShaderRegister");
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 12))
+ Sampler.RegisterSpace = *Val;
+ else
+ return reportError(Ctx, "Invalid value for RegisterSpace");
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 13))
+ Sampler.ShaderVisibility = *Val;
+ else
+ return reportError(Ctx, "Invalid value for ShaderVisibility");
+
+ RSD.StaticSamplers.push_back(Sampler);
+ return false;
+}
+
+llvm::Error MetadataParser::parseRootSignatureElement(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD, MDNode *Element){
+std::optional<StringRef> ElementText = extractMdStringValue(Element, 0);
+ if (!ElementText.has_value())
+ return reportError(Ctx, "Invalid format for Root Element");
+
+ RootSignatureElementKind ElementKind =
+ StringSwitch<RootSignatureElementKind>(*ElementText)
+ .Case("RootFlags", RootSignatureElementKind::RootFlags)
+ .Case("RootConstants", RootSignatureElementKind::RootConstants)
+ .Case("RootCBV", RootSignatureElementKind::CBV)
+ .Case("RootSRV", RootSignatureElementKind::SRV)
+ .Case("RootUAV", RootSignatureElementKind::UAV)
+ .Case("DescriptorTable", RootSignatureElementKind::DescriptorTable)
+ .Case("StaticSampler", RootSignatureElementKind::StaticSamplers)
+ .Default(RootSignatureElementKind::Error);
+
+ switch (ElementKind) {
+
+ case RootSignatureElementKind::RootFlags:
+ return parseRootFlags(Ctx, RSD, Element);
+ case RootSignatureElementKind::RootConstants:
+ return parseRootConstants(Ctx, RSD, Element);
+ case RootSignatureElementKind::CBV:
+ case RootSignatureElementKind::SRV:
+ case RootSignatureElementKind::UAV:
+ return parseRootDescriptors(Ctx, RSD, Element, ElementKind);
+ case RootSignatureElementKind::DescriptorTable:
+ return parseDescriptorTable(Ctx, RSD, Element);
+ case RootSignatureElementKind::StaticSamplers:
+ return parseStaticSampler(Ctx, RSD, Element);
+ case RootSignatureElementKind::Error:
+ return reportError(Ctx, "Invalid Root Signature Element: " + *ElementText);
+ }
+
+ llvm_unreachable("Unhandled RootSignatureElementKind enum.");
+}
+
+
} // namespace rootsig
} // namespace hlsl
} // namespace llvm
diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.h b/llvm/lib/Target/DirectX/DXILRootSignature.h
index fc39b38258df8..76328bb15fa58 100644
--- a/llvm/lib/Target/DirectX/DXILRootSignature.h
+++ b/llvm/lib/Target/DirectX/DXILRootSignature.h
@@ -26,16 +26,7 @@
namespace llvm {
namespace dxil {
-enum class RootSignatureElementKind {
- Error = 0,
- RootFlags = 1,
- RootConstants = 2,
- SRV = 3,
- UAV = 4,
- CBV = 5,
- DescriptorTable = 6,
- StaticSamplers = 7
-};
+
class RootSignatureBindingInfo {
private:
>From 7436dfecafd1a48f2ec5a7d968b218a0a5e42ef8 Mon Sep 17 00:00:00 2001
From: joaosaffran <joao.saffran at microsoft.com>
Date: Wed, 16 Jul 2025 23:00:15 +0000
Subject: [PATCH 02/10] moving to Metadata lib
---
.../Frontend/HLSL/RootSignatureMetadata.h | 128 ++++-
.../llvm/MC/DXContainerRootSignature.h | 5 +
.../Frontend/HLSL/RootSignatureMetadata.cpp | 389 ++++++++++----
llvm/lib/Target/DirectX/DXILRootSignature.cpp | 489 +-----------------
llvm/lib/Target/DirectX/DXILRootSignature.h | 2 -
...ature-DescriptorTable-Invalid-RangeType.ll | 2 +-
.../RootSignature-Flags-Error.ll | 2 +-
...ure-RootDescriptor-Invalid-RegisterKind.ll | 2 +-
...Signature-StaticSamplers-Invalid-MaxLod.ll | 2 +-
...Signature-StaticSamplers-Invalid-MinLod.ll | 2 +-
...ature-StaticSamplers-Invalid-MinLopBias.ll | 2 +-
11 files changed, 428 insertions(+), 597 deletions(-)
diff --git a/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h b/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h
index 5aa6c30491025..6f337660ee6c8 100644
--- a/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h
+++ b/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h
@@ -15,9 +15,11 @@
#define LLVM_FRONTEND_HLSL_ROOTSIGNATUREMETADATA_H
#include "llvm/Frontend/HLSL/HLSLRootSignature.h"
-#include "llvm/MC/DXContainerRootSignature.h"
+#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
+#include "llvm/MC/DXContainerRootSignature.h"
#include "llvm/Support/Error.h"
+#include <cstdint>
#include <unordered_map>
namespace llvm {
@@ -28,6 +30,96 @@ class Metadata;
namespace hlsl {
namespace rootsig {
+inline std::optional<uint32_t> extractMdIntValue(MDNode *Node,
+ unsigned int OpId) {
+ if (auto *CI =
+ mdconst::dyn_extract_or_null<ConstantInt>(Node->getOperand(OpId).get()))
+ return CI->getZExtValue();
+ return std::nullopt;
+}
+
+inline std::optional<float> extractMdFloatValue(MDNode *Node,
+ unsigned int OpId) {
+ if (auto *CFP = mdconst::dyn_extract_or_null<ConstantFP>(Node->getOperand(OpId).get()))
+ return CFP->getValueAPF().convertToFloat();
+ return std::nullopt;
+}
+
+inline std::optional<StringRef> extractMdStringValue(MDNode *Node,
+ unsigned int OpId) {
+ MDString *NodeText = dyn_cast_or_null<MDString>(Node->getOperand(OpId).get());
+ if (NodeText == nullptr)
+ return std::nullopt;
+ return NodeText->getString();
+}
+
+template <typename T>
+class RootSignatureValidationError
+ : public ErrorInfo<RootSignatureValidationError<T>> {
+public:
+ static char ID;
+ std::string ParamName;
+ T Value;
+
+ RootSignatureValidationError(StringRef ParamName, T Value)
+ : ParamName(ParamName.str()), Value(Value) {}
+
+ void log(raw_ostream &OS) const override {
+ OS << "Invalid value for " << ParamName << ": " << Value;
+ }
+
+ std::error_code convertToErrorCode() const override {
+ return llvm::inconvertibleErrorCode();
+ }
+};
+
+class GenericRSMetadataError : public ErrorInfo<GenericRSMetadataError> {
+public:
+ static char ID;
+ std::string Message;
+
+ GenericRSMetadataError(const Twine &Message) : Message(Message.str()) {}
+
+ void log(raw_ostream &OS) const override { OS << Message; }
+
+ std::error_code convertToErrorCode() const override {
+ return llvm::inconvertibleErrorCode();
+ }
+};
+
+class InvalidRSMetadataFormat : public ErrorInfo<InvalidRSMetadataFormat> {
+public:
+ static char ID;
+ std::string ElementName;
+
+ InvalidRSMetadataFormat(StringRef ElementName)
+ : ElementName(ElementName.str()) {}
+
+ void log(raw_ostream &OS) const override {
+ OS << "Invalid format for " << ElementName;
+ }
+
+ std::error_code convertToErrorCode() const override {
+ return llvm::inconvertibleErrorCode();
+ }
+};
+
+class InvalidRSMetadataValue : public ErrorInfo<InvalidRSMetadataValue> {
+public:
+ static char ID;
+ std::string ParamName;
+
+ InvalidRSMetadataValue(StringRef ParamName) : ParamName(ParamName.str()) {}
+
+ void log(raw_ostream &OS) const override {
+ OS << "Invalid value for " << ParamName;
+ }
+
+ std::error_code convertToErrorCode() const override {
+ return llvm::inconvertibleErrorCode();
+ }
+};
+
class MetadataBuilder {
public:
MetadataBuilder(llvm::LLVMContext &Ctx, ArrayRef<RootElement> Elements)
@@ -66,22 +158,32 @@ enum class RootSignatureElementKind {
class MetadataParser {
public:
- using MapT = SmallDenseMap<const Function *, llvm::mcdxbc::RootSignatureDesc>;
- MetadataParser(llvm::LLVMContext &Ctx, MDNode* Root): Ctx(Ctx), Root(Root) {}
+ MetadataParser(MDNode *Root) : Root(Root) {}
/// Iterates through root signature and converts them into MapT
- LLVM_ABI llvm::Expected<MapT*> ParseRootSignature();
+ LLVM_ABI llvm::Expected<llvm::mcdxbc::RootSignatureDesc>
+ ParseRootSignature(uint32_t Version);
private:
- llvm::Error parseRootFlags(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD, MDNode *RootFlagNode);
- llvm::Error parseRootConstants(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD, MDNode *RootConstantNode);
- llvm::Error parseRootDescriptors(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD, MDNode *RootDescriptorNode, RootSignatureElementKind ElementKind);
- llvm::Error parseDescriptorRange(LLVMContext *Ctx, mcdxbc::DescriptorTable &Table, MDNode *RangeDescriptorNode);
- llvm::Error parseDescriptorTable(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD, MDNode *DescriptorTableNode);
- llvm::Error parseRootSignatureElement(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD, MDNode *Element);
- llvm::Error parseStaticSampler(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD, MDNode *StaticSamplerNode);
- llvm::LLVMContext &Ctx;
- MDNode* Root;
+ llvm::Error parseRootFlags(mcdxbc::RootSignatureDesc &RSD,
+ MDNode *RootFlagNode);
+ llvm::Error parseRootConstants(mcdxbc::RootSignatureDesc &RSD,
+ MDNode *RootConstantNode);
+ llvm::Error parseRootDescriptors(mcdxbc::RootSignatureDesc &RSD,
+ MDNode *RootDescriptorNode,
+ RootSignatureElementKind ElementKind);
+ llvm::Error parseDescriptorRange(mcdxbc::DescriptorTable &Table,
+ MDNode *RangeDescriptorNode);
+ llvm::Error parseDescriptorTable(mcdxbc::RootSignatureDesc &RSD,
+ MDNode *DescriptorTableNode);
+ llvm::Error parseRootSignatureElement(mcdxbc::RootSignatureDesc &RSD,
+ MDNode *Element);
+ llvm::Error parseStaticSampler(mcdxbc::RootSignatureDesc &RSD,
+ MDNode *StaticSamplerNode);
+
+ llvm::Error validateRootSignature(const llvm::mcdxbc::RootSignatureDesc &RSD);
+
+ MDNode *Root;
};
} // namespace rootsig
diff --git a/llvm/include/llvm/MC/DXContainerRootSignature.h b/llvm/include/llvm/MC/DXContainerRootSignature.h
index 4b6b42f7d74f7..14a2429ffcc78 100644
--- a/llvm/include/llvm/MC/DXContainerRootSignature.h
+++ b/llvm/include/llvm/MC/DXContainerRootSignature.h
@@ -6,6 +6,9 @@
//
//===----------------------------------------------------------------------===//
+#ifndef LLVM_MC_DXCONTAINERROOTSIGNATURE_H
+#define LLVM_MC_DXCONTAINERROOTSIGNATURE_H
+
#include "llvm/BinaryFormat/DXContainer.h"
#include <cstdint>
#include <limits>
@@ -116,3 +119,5 @@ struct RootSignatureDesc {
};
} // namespace mcdxbc
} // namespace llvm
+
+#endif // LLVM_MC_DXCONTAINERROOTSIGNATURE_H
diff --git a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
index 2ae305dfb0f86..89e130db796db 100644
--- a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
+++ b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
@@ -12,16 +12,25 @@
//===----------------------------------------------------------------------===//
#include "llvm/Frontend/HLSL/RootSignatureMetadata.h"
+#include "llvm/Frontend/HLSL/RootSignatureValidations.h"
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Metadata.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/ScopedPrinter.h"
+#include <cstdint>
+#include <utility>
namespace llvm {
namespace hlsl {
namespace rootsig {
+char GenericRSMetadataError::ID;
+char InvalidRSMetadataFormat::ID;
+char InvalidRSMetadataValue::ID;
+
+template <typename T> char RootSignatureValidationError<T>::ID;
+
static const EnumEntry<dxil::ResourceClass> ResourceClassNames[] = {
{"CBV", dxil::ResourceClass::CBuffer},
{"SRV", dxil::ResourceClass::SRV},
@@ -36,42 +45,6 @@ static std::optional<StringRef> getResourceName(dxil::ResourceClass Class) {
return std::nullopt;
}
-static bool reportError(LLVMContext *Ctx, Twine Message,
- DiagnosticSeverity Severity = DS_Error) {
- Ctx->diagnose(DiagnosticInfoGeneric(Message, Severity));
- return true;
-}
-
-static bool reportValueError(LLVMContext *Ctx, Twine ParamName,
- uint32_t Value) {
- Ctx->diagnose(DiagnosticInfoGeneric(
- "Invalid value for " + ParamName + ": " + Twine(Value), DS_Error));
- return true;
-}
-
-static std::optional<uint32_t> extractMdIntValue(MDNode *Node,
- unsigned int OpId) {
- if (auto *CI =
- mdconst::dyn_extract<ConstantInt>(Node->getOperand(OpId).get()))
- return CI->getZExtValue();
- return std::nullopt;
-}
-
-static std::optional<float> extractMdFloatValue(MDNode *Node,
- unsigned int OpId) {
- if (auto *CI = mdconst::dyn_extract<ConstantFP>(Node->getOperand(OpId).get()))
- return CI->getValueAPF().convertToFloat();
- return std::nullopt;
-}
-
-static std::optional<StringRef> extractMdStringValue(MDNode *Node,
- unsigned int OpId) {
- MDString *NodeText = dyn_cast<MDString>(Node->getOperand(OpId));
- if (NodeText == nullptr)
- return std::nullopt;
- return NodeText->getString();
-}
-
namespace {
// We use the OverloadVisit with std::visit to ensure the compiler catches if a
@@ -227,21 +200,23 @@ MDNode *MetadataBuilder::BuildStaticSampler(const StaticSampler &Sampler) {
return MDNode::get(Ctx, Operands);
}
-llvm::Error MetadataParser::parseRootFlags(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD, MDNode *RootFlagNode){
+llvm::Error MetadataParser::parseRootFlags(mcdxbc::RootSignatureDesc &RSD,
+ MDNode *RootFlagNode) {
if (RootFlagNode->getNumOperands() != 2)
- return reportError(Ctx, "Invalid format for RootFlag Element");
+ return make_error<InvalidRSMetadataFormat>("RootFlag Element");
if (std::optional<uint32_t> Val = extractMdIntValue(RootFlagNode, 1))
RSD.Flags = *Val;
else
- return reportError(Ctx, "Invalid value for RootFlag");
+ return make_error<InvalidRSMetadataValue>("RootFlag");
- return false;
+ return llvm::Error::success();
}
-llvm::Error MetadataParser::parseRootConstants(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD, MDNode *RootConstantNode){
-if (RootConstantNode->getNumOperands() != 5)
- return reportError(Ctx, "Invalid format for RootConstants Element");
+llvm::Error MetadataParser::parseRootConstants(mcdxbc::RootSignatureDesc &RSD,
+ MDNode *RootConstantNode) {
+ if (RootConstantNode->getNumOperands() != 5)
+ return make_error<InvalidRSMetadataFormat>("RootConstants Element");
dxbc::RTS0::v1::RootParameterHeader Header;
// The parameter offset doesn't matter here - we recalculate it during
@@ -252,37 +227,40 @@ if (RootConstantNode->getNumOperands() != 5)
if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 1))
Header.ShaderVisibility = *Val;
else
- return reportError(Ctx, "Invalid value for ShaderVisibility");
+ return make_error<InvalidRSMetadataValue>("ShaderVisibility");
dxbc::RTS0::v1::RootConstants Constants;
if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 2))
Constants.ShaderRegister = *Val;
else
- return reportError(Ctx, "Invalid value for ShaderRegister");
+ return make_error<InvalidRSMetadataValue>("ShaderRegister");
if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 3))
Constants.RegisterSpace = *Val;
else
- return reportError(Ctx, "Invalid value for RegisterSpace");
+ return make_error<InvalidRSMetadataValue>("RegisterSpace");
if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 4))
Constants.Num32BitValues = *Val;
else
- return reportError(Ctx, "Invalid value for Num32BitValues");
+ return make_error<InvalidRSMetadataValue>("Num32BitValues");
RSD.ParametersContainer.addParameter(Header, Constants);
- return false;
+ return llvm::Error::success();
}
-llvm::Error MetadataParser::parseRootDescriptors(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD, MDNode *RootDescriptorNode, RootSignatureElementKind ElementKind){
-assert(ElementKind == RootSignatureElementKind::SRV ||
+llvm::Error
+MetadataParser::parseRootDescriptors(mcdxbc::RootSignatureDesc &RSD,
+ MDNode *RootDescriptorNode,
+ RootSignatureElementKind ElementKind) {
+ assert(ElementKind == RootSignatureElementKind::SRV ||
ElementKind == RootSignatureElementKind::UAV ||
ElementKind == RootSignatureElementKind::CBV &&
"parseRootDescriptors should only be called with RootDescriptor "
"element kind.");
if (RootDescriptorNode->getNumOperands() != 5)
- return reportError(Ctx, "Invalid format for Root Descriptor Element");
+ return make_error<InvalidRSMetadataFormat>("Root Descriptor Element");
dxbc::RTS0::v1::RootParameterHeader Header;
switch (ElementKind) {
@@ -303,37 +281,38 @@ assert(ElementKind == RootSignatureElementKind::SRV ||
if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 1))
Header.ShaderVisibility = *Val;
else
- return reportError(Ctx, "Invalid value for ShaderVisibility");
+ return make_error<InvalidRSMetadataValue>("ShaderVisibility");
dxbc::RTS0::v2::RootDescriptor Descriptor;
if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 2))
Descriptor.ShaderRegister = *Val;
else
- return reportError(Ctx, "Invalid value for ShaderRegister");
+ return make_error<InvalidRSMetadataValue>("ShaderRegister");
if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 3))
Descriptor.RegisterSpace = *Val;
else
- return reportError(Ctx, "Invalid value for RegisterSpace");
+ return make_error<InvalidRSMetadataValue>("RegisterSpace");
if (RSD.Version == 1) {
RSD.ParametersContainer.addParameter(Header, Descriptor);
- return false;
+ return llvm::Error::success();
}
assert(RSD.Version > 1);
if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 4))
Descriptor.Flags = *Val;
else
- return reportError(Ctx, "Invalid value for Root Descriptor Flags");
+ return make_error<InvalidRSMetadataValue>("Root Descriptor Flags");
RSD.ParametersContainer.addParameter(Header, Descriptor);
- return false;
+ return llvm::Error::success();
}
-llvm::Error MetadataParser::parseDescriptorRange(LLVMContext *Ctx, mcdxbc::DescriptorTable &Table, MDNode *RangeDescriptorNode){
-if (RangeDescriptorNode->getNumOperands() != 6)
- return reportError(Ctx, "Invalid format for Descriptor Range");
+/// Parses one descriptor-range node of a descriptor table and appends it to
+/// \p Table. The node must have exactly six operands, in order: the range-type
+/// string, NumDescriptors, BaseShaderRegister, RegisterSpace,
+/// OffsetInDescriptorsFromTableStart and Flags. The first problem found is
+/// returned as an error, and \p Table is only modified on success.
+llvm::Error MetadataParser::parseDescriptorRange(mcdxbc::DescriptorTable &Table,
+ MDNode *RangeDescriptorNode) {
+ if (RangeDescriptorNode->getNumOperands() != 6)
+ return make_error<InvalidRSMetadataFormat>("Descriptor Range");
dxbc::RTS0::v2::DescriptorRange Range;
@@ -341,7 +320,7 @@ if (RangeDescriptorNode->getNumOperands() != 6)
extractMdStringValue(RangeDescriptorNode, 0);
if (!ElementText.has_value())
- return reportError(Ctx, "Descriptor Range, first element is not a string.");
+ return make_error<InvalidRSMetadataFormat>("Descriptor Range");
Range.RangeType =
StringSwitch<uint32_t>(*ElementText)
@@ -353,48 +332,50 @@ if (RangeDescriptorNode->getNumOperands() != 6)
.Default(~0U);
if (Range.RangeType == ~0U)
- return reportError(Ctx, "Invalid Descriptor Range type: " + *ElementText);
+ return make_error<GenericRSMetadataError>("Invalid Descriptor Range type: " +
+ *ElementText);
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 1))
Range.NumDescriptors = *Val;
else
- return reportError(Ctx, "Invalid value for Number of Descriptor in Range");
+ return make_error<InvalidRSMetadataValue>("Number of Descriptor in Range");
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 2))
Range.BaseShaderRegister = *Val;
else
- return reportError(Ctx, "Invalid value for BaseShaderRegister");
+ return make_error<InvalidRSMetadataValue>("BaseShaderRegister");
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 3))
Range.RegisterSpace = *Val;
else
- return reportError(Ctx, "Invalid value for RegisterSpace");
+ return make_error<InvalidRSMetadataValue>("RegisterSpace");
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 4))
Range.OffsetInDescriptorsFromTableStart = *Val;
else
- return reportError(Ctx,
- "Invalid value for OffsetInDescriptorsFromTableStart");
+ return make_error<InvalidRSMetadataValue>(
+ "OffsetInDescriptorsFromTableStart");
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 5))
Range.Flags = *Val;
else
- return reportError(Ctx, "Invalid value for Descriptor Range Flags");
+ return make_error<InvalidRSMetadataValue>("Descriptor Range Flags");
Table.Ranges.push_back(Range);
- return false;
+ return llvm::Error::success();
}
-llvm::Error MetadataParser::parseDescriptorTable(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD, MDNode *DescriptorTableNode){
-const unsigned int NumOperands = DescriptorTableNode->getNumOperands();
+/// Parses a descriptor-table node. Operand 1 is the shader visibility, and
+/// every operand from index 2 onward must be a descriptor-range MDNode (a
+/// table with no ranges is accepted). The first range that fails to parse
+/// aborts the table and its error is returned; the table is only added to
+/// \p RSD when every range parsed successfully.
+llvm::Error MetadataParser::parseDescriptorTable(mcdxbc::RootSignatureDesc &RSD,
+ MDNode *DescriptorTableNode) {
+ const unsigned int NumOperands = DescriptorTableNode->getNumOperands();
if (NumOperands < 2)
- return reportError(Ctx, "Invalid format for Descriptor Table");
+ return make_error<InvalidRSMetadataFormat>("Descriptor Table");
dxbc::RTS0::v1::RootParameterHeader Header;
if (std::optional<uint32_t> Val = extractMdIntValue(DescriptorTableNode, 1))
Header.ShaderVisibility = *Val;
else
- return reportError(Ctx, "Invalid value for ShaderVisibility");
+ return make_error<InvalidRSMetadataValue>("ShaderVisibility");
mcdxbc::DescriptorTable Table;
Header.ParameterType =
@@ -403,94 +384,98 @@ const unsigned int NumOperands = DescriptorTableNode->getNumOperands();
for (unsigned int I = 2; I < NumOperands; I++) {
MDNode *Element = dyn_cast<MDNode>(DescriptorTableNode->getOperand(I));
if (Element == nullptr)
- return reportError(Ctx, "Missing Root Element Metadata Node.");
+ return make_error<GenericRSMetadataError>(
+ "Missing Root Element Metadata Node.");
- if (parseDescriptorRange(Ctx, Table, Element))
- return true;
+ if (auto Err = parseDescriptorRange(Table, Element))
+ return Err;
}
RSD.ParametersContainer.addParameter(Header, Table);
- return false;
+ return llvm::Error::success();
}
-llvm::Error parseStaticSampler(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD, MDNode *StaticSamplerNode) {
+/// Parses a static-sampler node (exactly 14 operands) and appends the result
+/// to \p RSD.StaticSamplers. Operands 5, 9 and 10 (MipLODBias, MinLOD, MaxLOD)
+/// are read as floats; operands 1-4, 6-8 and 11-13 are read as integers. The
+/// first operand that cannot be read is reported by name.
+llvm::Error MetadataParser::parseStaticSampler(mcdxbc::RootSignatureDesc &RSD,
+ MDNode *StaticSamplerNode) {
if (StaticSamplerNode->getNumOperands() != 14)
- return reportError(Ctx, "Invalid format for Static Sampler");
+ return make_error<InvalidRSMetadataFormat>("Static Sampler");
dxbc::RTS0::v1::StaticSampler Sampler;
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 1))
Sampler.Filter = *Val;
else
- return reportError(Ctx, "Invalid value for Filter");
+ return make_error<InvalidRSMetadataValue>("Filter");
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 2))
Sampler.AddressU = *Val;
else
- return reportError(Ctx, "Invalid value for AddressU");
+ return make_error<InvalidRSMetadataValue>("AddressU");
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 3))
Sampler.AddressV = *Val;
else
- return reportError(Ctx, "Invalid value for AddressV");
+ return make_error<InvalidRSMetadataValue>("AddressV");
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 4))
Sampler.AddressW = *Val;
else
- return reportError(Ctx, "Invalid value for AddressW");
+ return make_error<InvalidRSMetadataValue>("AddressW");
if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 5))
Sampler.MipLODBias = *Val;
else
- return reportError(Ctx, "Invalid value for MipLODBias");
+ return make_error<InvalidRSMetadataValue>("MipLODBias");
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 6))
Sampler.MaxAnisotropy = *Val;
else
- return reportError(Ctx, "Invalid value for MaxAnisotropy");
+ return make_error<InvalidRSMetadataValue>("MaxAnisotropy");
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 7))
Sampler.ComparisonFunc = *Val;
else
- return reportError(Ctx, "Invalid value for ComparisonFunc ");
+ return make_error<InvalidRSMetadataValue>("ComparisonFunc");
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 8))
Sampler.BorderColor = *Val;
else
- return reportError(Ctx, "Invalid value for ComparisonFunc ");
+ return make_error<InvalidRSMetadataValue>("BorderColor");
if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 9))
Sampler.MinLOD = *Val;
else
- return reportError(Ctx, "Invalid value for MinLOD");
+ return make_error<InvalidRSMetadataValue>("MinLOD");
if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 10))
Sampler.MaxLOD = *Val;
else
- return reportError(Ctx, "Invalid value for MaxLOD");
+ return make_error<InvalidRSMetadataValue>("MaxLOD");
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 11))
Sampler.ShaderRegister = *Val;
else
- return reportError(Ctx, "Invalid value for ShaderRegister");
+ return make_error<InvalidRSMetadataValue>("ShaderRegister");
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 12))
Sampler.RegisterSpace = *Val;
else
- return reportError(Ctx, "Invalid value for RegisterSpace");
+ return make_error<InvalidRSMetadataValue>("RegisterSpace");
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 13))
Sampler.ShaderVisibility = *Val;
else
- return reportError(Ctx, "Invalid value for ShaderVisibility");
+ return make_error<InvalidRSMetadataValue>("ShaderVisibility");
RSD.StaticSamplers.push_back(Sampler);
- return false;
+ return llvm::Error::success();
}
-llvm::Error MetadataParser::parseRootSignatureElement(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD, MDNode *Element){
-std::optional<StringRef> ElementText = extractMdStringValue(Element, 0);
+/// Dispatches a single root-element node to the matching parser, selected by
+/// the string tag in operand 0. A missing tag is a format error, and an
+/// unrecognized tag produces a GenericRSMetadataError naming the tag.
+llvm::Error
+MetadataParser::parseRootSignatureElement(mcdxbc::RootSignatureDesc &RSD,
+ MDNode *Element) {
+ std::optional<StringRef> ElementText = extractMdStringValue(Element, 0);
if (!ElementText.has_value())
- return reportError(Ctx, "Invalid format for Root Element");
+ return make_error<InvalidRSMetadataFormat>("Root Element");
RootSignatureElementKind ElementKind =
StringSwitch<RootSignatureElementKind>(*ElementText)
@@ -506,25 +491,223 @@ std::optional<StringRef> ElementText = extractMdStringValue(Element, 0);
switch (ElementKind) {
case RootSignatureElementKind::RootFlags:
- return parseRootFlags(Ctx, RSD, Element);
+ return parseRootFlags(RSD, Element);
case RootSignatureElementKind::RootConstants:
- return parseRootConstants(Ctx, RSD, Element);
+ return parseRootConstants(RSD, Element);
case RootSignatureElementKind::CBV:
case RootSignatureElementKind::SRV:
case RootSignatureElementKind::UAV:
- return parseRootDescriptors(Ctx, RSD, Element, ElementKind);
+ return parseRootDescriptors(RSD, Element, ElementKind);
case RootSignatureElementKind::DescriptorTable:
- return parseDescriptorTable(Ctx, RSD, Element);
+ return parseDescriptorTable(RSD, Element);
case RootSignatureElementKind::StaticSamplers:
- return parseStaticSampler(Ctx, RSD, Element);
+ return parseStaticSampler(RSD, Element);
case RootSignatureElementKind::Error:
- return reportError(Ctx, "Invalid Root Signature Element: " + *ElementText);
+ return make_error<GenericRSMetadataError>(
+ "Invalid Root Signature Element: " + *ElementText);
}
llvm_unreachable("Unhandled RootSignatureElementKind enum.");
}
+/// Checks every field of a parsed root signature with the hlsl::rootsig
+/// verify* predicates. Validation does not stop at the first bad field: each
+/// invalid value adds a RootSignatureValidationError naming the field and its
+/// value, and all of them are returned joined together in the order checked
+/// (Error::success() if everything is valid).
+llvm::Error MetadataParser::validateRootSignature(
+ const llvm::mcdxbc::RootSignatureDesc &RSD) {
+ Error DeferredErrs = Error::success();
+
+ // Appends a validation error for an integer-valued field. ParamName is
+ // always a string literal here, so it outlives the error object even if the
+ // error only keeps a reference to it.
+ auto ReportInvalidInt = [&DeferredErrs](const char *ParamName,
+ uint32_t Value) {
+ DeferredErrs =
+ joinErrors(std::move(DeferredErrs),
+ llvm::make_error<RootSignatureValidationError<uint32_t>>(
+ ParamName, Value));
+ };
+ // Same as ReportInvalidInt, for the float-valued sampler fields.
+ auto ReportInvalidFloat = [&DeferredErrs](const char *ParamName,
+ float Value) {
+ DeferredErrs =
+ joinErrors(std::move(DeferredErrs),
+ llvm::make_error<RootSignatureValidationError<float>>(
+ ParamName, Value));
+ };
+
+ if (!llvm::hlsl::rootsig::verifyVersion(RSD.Version))
+ ReportInvalidInt("Version", RSD.Version);
+
+ if (!llvm::hlsl::rootsig::verifyRootFlag(RSD.Flags))
+ ReportInvalidInt("RootFlags", RSD.Flags);
+
+ for (const mcdxbc::RootParameterInfo &Info : RSD.ParametersContainer) {
+ if (!dxbc::isValidShaderVisibility(Info.Header.ShaderVisibility))
+ ReportInvalidInt("ShaderVisibility", Info.Header.ShaderVisibility);
+
+ assert(dxbc::isValidParameterType(Info.Header.ParameterType) &&
+ "Invalid value for ParameterType");
+
+ switch (Info.Header.ParameterType) {
+ case llvm::to_underlying(dxbc::RootParameterType::CBV):
+ case llvm::to_underlying(dxbc::RootParameterType::UAV):
+ case llvm::to_underlying(dxbc::RootParameterType::SRV): {
+ const dxbc::RTS0::v2::RootDescriptor &Descriptor =
+ RSD.ParametersContainer.getRootDescriptor(Info.Location);
+ if (!llvm::hlsl::rootsig::verifyRegisterValue(Descriptor.ShaderRegister))
+ ReportInvalidInt("ShaderRegister", Descriptor.ShaderRegister);
+
+ if (!llvm::hlsl::rootsig::verifyRegisterSpace(Descriptor.RegisterSpace))
+ ReportInvalidInt("RegisterSpace", Descriptor.RegisterSpace);
+
+ // Root descriptor flags are only read from metadata when Version > 1.
+ if (RSD.Version > 1 && !llvm::hlsl::rootsig::verifyRootDescriptorFlag(
+ RSD.Version, Descriptor.Flags))
+ ReportInvalidInt("RootDescriptorFlag", Descriptor.Flags);
+ break;
+ }
+ case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
+ const mcdxbc::DescriptorTable &Table =
+ RSD.ParametersContainer.getDescriptorTable(Info.Location);
+ for (const dxbc::RTS0::v2::DescriptorRange &Range : Table) {
+ if (!llvm::hlsl::rootsig::verifyRangeType(Range.RangeType))
+ ReportInvalidInt("RangeType", Range.RangeType);
+
+ if (!llvm::hlsl::rootsig::verifyRegisterSpace(Range.RegisterSpace))
+ ReportInvalidInt("RegisterSpace", Range.RegisterSpace);
+
+ if (!llvm::hlsl::rootsig::verifyNumDescriptors(Range.NumDescriptors))
+ ReportInvalidInt("NumDescriptors", Range.NumDescriptors);
+
+ if (!llvm::hlsl::rootsig::verifyDescriptorRangeFlag(
+ RSD.Version, Range.RangeType, Range.Flags))
+ ReportInvalidInt("DescriptorFlag", Range.Flags);
+ }
+ break;
+ }
+ }
+ }
+
+ for (const dxbc::RTS0::v1::StaticSampler &Sampler : RSD.StaticSamplers) {
+ if (!llvm::hlsl::rootsig::verifySamplerFilter(Sampler.Filter))
+ ReportInvalidInt("Filter", Sampler.Filter);
+
+ if (!llvm::hlsl::rootsig::verifyAddress(Sampler.AddressU))
+ ReportInvalidInt("AddressU", Sampler.AddressU);
+
+ if (!llvm::hlsl::rootsig::verifyAddress(Sampler.AddressV))
+ ReportInvalidInt("AddressV", Sampler.AddressV);
+
+ if (!llvm::hlsl::rootsig::verifyAddress(Sampler.AddressW))
+ ReportInvalidInt("AddressW", Sampler.AddressW);
+
+ if (!llvm::hlsl::rootsig::verifyMipLODBias(Sampler.MipLODBias))
+ ReportInvalidFloat("MipLODBias", Sampler.MipLODBias);
+
+ if (!llvm::hlsl::rootsig::verifyMaxAnisotropy(Sampler.MaxAnisotropy))
+ ReportInvalidInt("MaxAnisotropy", Sampler.MaxAnisotropy);
+
+ if (!llvm::hlsl::rootsig::verifyComparisonFunc(Sampler.ComparisonFunc))
+ ReportInvalidInt("ComparisonFunc", Sampler.ComparisonFunc);
+
+ if (!llvm::hlsl::rootsig::verifyBorderColor(Sampler.BorderColor))
+ ReportInvalidInt("BorderColor", Sampler.BorderColor);
+
+ if (!llvm::hlsl::rootsig::verifyLOD(Sampler.MinLOD))
+ ReportInvalidFloat("MinLOD", Sampler.MinLOD);
+
+ if (!llvm::hlsl::rootsig::verifyLOD(Sampler.MaxLOD))
+ ReportInvalidFloat("MaxLOD", Sampler.MaxLOD);
+
+ if (!llvm::hlsl::rootsig::verifyRegisterValue(Sampler.ShaderRegister))
+ ReportInvalidInt("ShaderRegister", Sampler.ShaderRegister);
+
+ if (!llvm::hlsl::rootsig::verifyRegisterSpace(Sampler.RegisterSpace))
+ ReportInvalidInt("RegisterSpace", Sampler.RegisterSpace);
+
+ if (!dxbc::isValidShaderVisibility(Sampler.ShaderVisibility))
+ ReportInvalidInt("ShaderVisibility", Sampler.ShaderVisibility);
+ }
+
+ return DeferredErrs;
+}
+
+/// Parses every root element under the root node into a RootSignatureDesc
+/// for the given \p Version, then validates the result. Errors from
+/// individual elements are accumulated rather than stopping the parse, and
+/// validation runs even when some elements failed, so its errors are appended
+/// after the parse errors. An operand that is not an MDNode aborts
+/// immediately, returning the errors gathered so far plus that one.
+/// Returns the parsed descriptor only when no error was recorded.
+llvm::Expected<mcdxbc::RootSignatureDesc>
+MetadataParser::ParseRootSignature(uint32_t Version) {
+ Error DeferredErrs = Error::success();
+ mcdxbc::RootSignatureDesc RSD;
+ RSD.Version = Version;
+ for (const auto &Operand : Root->operands()) {
+ MDNode *Element = dyn_cast<MDNode>(Operand);
+ if (Element == nullptr)
+ return joinErrors(std::move(DeferredErrs),
+ make_error<GenericRSMetadataError>(
+ "Missing Root Element Metadata Node."));
+
+ if (auto Err = parseRootSignatureElement(RSD, Element)) {
+ DeferredErrs = joinErrors(std::move(DeferredErrs), std::move(Err));
+ }
+ }
+ if (auto Err = validateRootSignature(RSD))
+ DeferredErrs = joinErrors(std::move(DeferredErrs), std::move(Err));
+
+ if (DeferredErrs)
+ return std::move(DeferredErrs);
+
+ return std::move(RSD);
+}
} // namespace rootsig
} // namespace hlsl
} // namespace llvm
diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.cpp b/llvm/lib/Target/DirectX/DXILRootSignature.cpp
index dfc81626da01f..6459294169fe3 100644
--- a/llvm/lib/Target/DirectX/DXILRootSignature.cpp
+++ b/llvm/lib/Target/DirectX/DXILRootSignature.cpp
@@ -16,6 +16,7 @@
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/DXILMetadataAnalysis.h"
#include "llvm/BinaryFormat/DXContainer.h"
+#include "llvm/Frontend/HLSL/RootSignatureMetadata.h"
#include "llvm/Frontend/HLSL/RootSignatureValidations.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DiagnosticInfo.h"
@@ -29,8 +30,6 @@
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include <cstdint>
-#include <optional>
-#include <utility>
using namespace llvm;
using namespace llvm::dxil;
@@ -41,470 +40,6 @@ static bool reportError(LLVMContext *Ctx, Twine Message,
return true;
}
-static bool reportValueError(LLVMContext *Ctx, Twine ParamName,
- uint32_t Value) {
- Ctx->diagnose(DiagnosticInfoGeneric(
- "Invalid value for " + ParamName + ": " + Twine(Value), DS_Error));
- return true;
-}
-
-static std::optional<uint32_t> extractMdIntValue(MDNode *Node,
- unsigned int OpId) {
- if (auto *CI =
- mdconst::dyn_extract<ConstantInt>(Node->getOperand(OpId).get()))
- return CI->getZExtValue();
- return std::nullopt;
-}
-
-static std::optional<float> extractMdFloatValue(MDNode *Node,
- unsigned int OpId) {
- if (auto *CI = mdconst::dyn_extract<ConstantFP>(Node->getOperand(OpId).get()))
- return CI->getValueAPF().convertToFloat();
- return std::nullopt;
-}
-
-static std::optional<StringRef> extractMdStringValue(MDNode *Node,
- unsigned int OpId) {
- MDString *NodeText = dyn_cast<MDString>(Node->getOperand(OpId));
- if (NodeText == nullptr)
- return std::nullopt;
- return NodeText->getString();
-}
-
-static bool parseRootFlags(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
- MDNode *RootFlagNode) {
-
- if (RootFlagNode->getNumOperands() != 2)
- return reportError(Ctx, "Invalid format for RootFlag Element");
-
- if (std::optional<uint32_t> Val = extractMdIntValue(RootFlagNode, 1))
- RSD.Flags = *Val;
- else
- return reportError(Ctx, "Invalid value for RootFlag");
-
- return false;
-}
-
-static bool parseRootConstants(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
- MDNode *RootConstantNode) {
-
- if (RootConstantNode->getNumOperands() != 5)
- return reportError(Ctx, "Invalid format for RootConstants Element");
-
- dxbc::RTS0::v1::RootParameterHeader Header;
- // The parameter offset doesn't matter here - we recalculate it during
- // serialization Header.ParameterOffset = 0;
- Header.ParameterType =
- llvm::to_underlying(dxbc::RootParameterType::Constants32Bit);
-
- if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 1))
- Header.ShaderVisibility = *Val;
- else
- return reportError(Ctx, "Invalid value for ShaderVisibility");
-
- dxbc::RTS0::v1::RootConstants Constants;
- if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 2))
- Constants.ShaderRegister = *Val;
- else
- return reportError(Ctx, "Invalid value for ShaderRegister");
-
- if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 3))
- Constants.RegisterSpace = *Val;
- else
- return reportError(Ctx, "Invalid value for RegisterSpace");
-
- if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 4))
- Constants.Num32BitValues = *Val;
- else
- return reportError(Ctx, "Invalid value for Num32BitValues");
-
- RSD.ParametersContainer.addParameter(Header, Constants);
-
- return false;
-}
-
-static bool parseRootDescriptors(LLVMContext *Ctx,
- mcdxbc::RootSignatureDesc &RSD,
- MDNode *RootDescriptorNode,
- RootSignatureElementKind ElementKind) {
- assert(ElementKind == RootSignatureElementKind::SRV ||
- ElementKind == RootSignatureElementKind::UAV ||
- ElementKind == RootSignatureElementKind::CBV &&
- "parseRootDescriptors should only be called with RootDescriptor "
- "element kind.");
- if (RootDescriptorNode->getNumOperands() != 5)
- return reportError(Ctx, "Invalid format for Root Descriptor Element");
-
- dxbc::RTS0::v1::RootParameterHeader Header;
- switch (ElementKind) {
- case RootSignatureElementKind::SRV:
- Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::SRV);
- break;
- case RootSignatureElementKind::UAV:
- Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::UAV);
- break;
- case RootSignatureElementKind::CBV:
- Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::CBV);
- break;
- default:
- llvm_unreachable("invalid Root Descriptor kind");
- break;
- }
-
- if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 1))
- Header.ShaderVisibility = *Val;
- else
- return reportError(Ctx, "Invalid value for ShaderVisibility");
-
- dxbc::RTS0::v2::RootDescriptor Descriptor;
- if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 2))
- Descriptor.ShaderRegister = *Val;
- else
- return reportError(Ctx, "Invalid value for ShaderRegister");
-
- if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 3))
- Descriptor.RegisterSpace = *Val;
- else
- return reportError(Ctx, "Invalid value for RegisterSpace");
-
- if (RSD.Version == 1) {
- RSD.ParametersContainer.addParameter(Header, Descriptor);
- return false;
- }
- assert(RSD.Version > 1);
-
- if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 4))
- Descriptor.Flags = *Val;
- else
- return reportError(Ctx, "Invalid value for Root Descriptor Flags");
-
- RSD.ParametersContainer.addParameter(Header, Descriptor);
- return false;
-}
-
-static bool parseDescriptorRange(LLVMContext *Ctx,
- mcdxbc::DescriptorTable &Table,
- MDNode *RangeDescriptorNode) {
-
- if (RangeDescriptorNode->getNumOperands() != 6)
- return reportError(Ctx, "Invalid format for Descriptor Range");
-
- dxbc::RTS0::v2::DescriptorRange Range;
-
- std::optional<StringRef> ElementText =
- extractMdStringValue(RangeDescriptorNode, 0);
-
- if (!ElementText.has_value())
- return reportError(Ctx, "Descriptor Range, first element is not a string.");
-
- Range.RangeType =
- StringSwitch<uint32_t>(*ElementText)
- .Case("CBV", llvm::to_underlying(dxbc::DescriptorRangeType::CBV))
- .Case("SRV", llvm::to_underlying(dxbc::DescriptorRangeType::SRV))
- .Case("UAV", llvm::to_underlying(dxbc::DescriptorRangeType::UAV))
- .Case("Sampler",
- llvm::to_underlying(dxbc::DescriptorRangeType::Sampler))
- .Default(~0U);
-
- if (Range.RangeType == ~0U)
- return reportError(Ctx, "Invalid Descriptor Range type: " + *ElementText);
-
- if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 1))
- Range.NumDescriptors = *Val;
- else
- return reportError(Ctx, "Invalid value for Number of Descriptor in Range");
-
- if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 2))
- Range.BaseShaderRegister = *Val;
- else
- return reportError(Ctx, "Invalid value for BaseShaderRegister");
-
- if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 3))
- Range.RegisterSpace = *Val;
- else
- return reportError(Ctx, "Invalid value for RegisterSpace");
-
- if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 4))
- Range.OffsetInDescriptorsFromTableStart = *Val;
- else
- return reportError(Ctx,
- "Invalid value for OffsetInDescriptorsFromTableStart");
-
- if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 5))
- Range.Flags = *Val;
- else
- return reportError(Ctx, "Invalid value for Descriptor Range Flags");
-
- Table.Ranges.push_back(Range);
- return false;
-}
-
-static bool parseDescriptorTable(LLVMContext *Ctx,
- mcdxbc::RootSignatureDesc &RSD,
- MDNode *DescriptorTableNode) {
- const unsigned int NumOperands = DescriptorTableNode->getNumOperands();
- if (NumOperands < 2)
- return reportError(Ctx, "Invalid format for Descriptor Table");
-
- dxbc::RTS0::v1::RootParameterHeader Header;
- if (std::optional<uint32_t> Val = extractMdIntValue(DescriptorTableNode, 1))
- Header.ShaderVisibility = *Val;
- else
- return reportError(Ctx, "Invalid value for ShaderVisibility");
-
- mcdxbc::DescriptorTable Table;
- Header.ParameterType =
- llvm::to_underlying(dxbc::RootParameterType::DescriptorTable);
-
- for (unsigned int I = 2; I < NumOperands; I++) {
- MDNode *Element = dyn_cast<MDNode>(DescriptorTableNode->getOperand(I));
- if (Element == nullptr)
- return reportError(Ctx, "Missing Root Element Metadata Node.");
-
- if (parseDescriptorRange(Ctx, Table, Element))
- return true;
- }
-
- RSD.ParametersContainer.addParameter(Header, Table);
- return false;
-}
-
-static bool parseStaticSampler(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
- MDNode *StaticSamplerNode) {
- if (StaticSamplerNode->getNumOperands() != 14)
- return reportError(Ctx, "Invalid format for Static Sampler");
-
- dxbc::RTS0::v1::StaticSampler Sampler;
- if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 1))
- Sampler.Filter = *Val;
- else
- return reportError(Ctx, "Invalid value for Filter");
-
- if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 2))
- Sampler.AddressU = *Val;
- else
- return reportError(Ctx, "Invalid value for AddressU");
-
- if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 3))
- Sampler.AddressV = *Val;
- else
- return reportError(Ctx, "Invalid value for AddressV");
-
- if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 4))
- Sampler.AddressW = *Val;
- else
- return reportError(Ctx, "Invalid value for AddressW");
-
- if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 5))
- Sampler.MipLODBias = *Val;
- else
- return reportError(Ctx, "Invalid value for MipLODBias");
-
- if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 6))
- Sampler.MaxAnisotropy = *Val;
- else
- return reportError(Ctx, "Invalid value for MaxAnisotropy");
-
- if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 7))
- Sampler.ComparisonFunc = *Val;
- else
- return reportError(Ctx, "Invalid value for ComparisonFunc ");
-
- if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 8))
- Sampler.BorderColor = *Val;
- else
- return reportError(Ctx, "Invalid value for ComparisonFunc ");
-
- if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 9))
- Sampler.MinLOD = *Val;
- else
- return reportError(Ctx, "Invalid value for MinLOD");
-
- if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 10))
- Sampler.MaxLOD = *Val;
- else
- return reportError(Ctx, "Invalid value for MaxLOD");
-
- if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 11))
- Sampler.ShaderRegister = *Val;
- else
- return reportError(Ctx, "Invalid value for ShaderRegister");
-
- if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 12))
- Sampler.RegisterSpace = *Val;
- else
- return reportError(Ctx, "Invalid value for RegisterSpace");
-
- if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 13))
- Sampler.ShaderVisibility = *Val;
- else
- return reportError(Ctx, "Invalid value for ShaderVisibility");
-
- RSD.StaticSamplers.push_back(Sampler);
- return false;
-}
-
-static bool parseRootSignatureElement(LLVMContext *Ctx,
- mcdxbc::RootSignatureDesc &RSD,
- MDNode *Element) {
- std::optional<StringRef> ElementText = extractMdStringValue(Element, 0);
- if (!ElementText.has_value())
- return reportError(Ctx, "Invalid format for Root Element");
-
- RootSignatureElementKind ElementKind =
- StringSwitch<RootSignatureElementKind>(*ElementText)
- .Case("RootFlags", RootSignatureElementKind::RootFlags)
- .Case("RootConstants", RootSignatureElementKind::RootConstants)
- .Case("RootCBV", RootSignatureElementKind::CBV)
- .Case("RootSRV", RootSignatureElementKind::SRV)
- .Case("RootUAV", RootSignatureElementKind::UAV)
- .Case("DescriptorTable", RootSignatureElementKind::DescriptorTable)
- .Case("StaticSampler", RootSignatureElementKind::StaticSamplers)
- .Default(RootSignatureElementKind::Error);
-
- switch (ElementKind) {
-
- case RootSignatureElementKind::RootFlags:
- return parseRootFlags(Ctx, RSD, Element);
- case RootSignatureElementKind::RootConstants:
- return parseRootConstants(Ctx, RSD, Element);
- case RootSignatureElementKind::CBV:
- case RootSignatureElementKind::SRV:
- case RootSignatureElementKind::UAV:
- return parseRootDescriptors(Ctx, RSD, Element, ElementKind);
- case RootSignatureElementKind::DescriptorTable:
- return parseDescriptorTable(Ctx, RSD, Element);
- case RootSignatureElementKind::StaticSamplers:
- return parseStaticSampler(Ctx, RSD, Element);
- case RootSignatureElementKind::Error:
- return reportError(Ctx, "Invalid Root Signature Element: " + *ElementText);
- }
-
- llvm_unreachable("Unhandled RootSignatureElementKind enum.");
-}
-
-static bool parse(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
- MDNode *Node) {
- bool HasError = false;
-
- // Loop through the Root Elements of the root signature.
- for (const auto &Operand : Node->operands()) {
- MDNode *Element = dyn_cast<MDNode>(Operand);
- if (Element == nullptr)
- return reportError(Ctx, "Missing Root Element Metadata Node.");
-
- HasError = HasError || parseRootSignatureElement(Ctx, RSD, Element);
- }
-
- return HasError;
-}
-
-static bool validate(LLVMContext *Ctx, const mcdxbc::RootSignatureDesc &RSD) {
-
- if (!llvm::hlsl::rootsig::verifyVersion(RSD.Version)) {
- return reportValueError(Ctx, "Version", RSD.Version);
- }
-
- if (!llvm::hlsl::rootsig::verifyRootFlag(RSD.Flags)) {
- return reportValueError(Ctx, "RootFlags", RSD.Flags);
- }
-
- for (const mcdxbc::RootParameterInfo &Info : RSD.ParametersContainer) {
- if (!dxbc::isValidShaderVisibility(Info.Header.ShaderVisibility))
- return reportValueError(Ctx, "ShaderVisibility",
- Info.Header.ShaderVisibility);
-
- assert(dxbc::isValidParameterType(Info.Header.ParameterType) &&
- "Invalid value for ParameterType");
-
- switch (Info.Header.ParameterType) {
-
- case llvm::to_underlying(dxbc::RootParameterType::CBV):
- case llvm::to_underlying(dxbc::RootParameterType::UAV):
- case llvm::to_underlying(dxbc::RootParameterType::SRV): {
- const dxbc::RTS0::v2::RootDescriptor &Descriptor =
- RSD.ParametersContainer.getRootDescriptor(Info.Location);
- if (!llvm::hlsl::rootsig::verifyRegisterValue(Descriptor.ShaderRegister))
- return reportValueError(Ctx, "ShaderRegister",
- Descriptor.ShaderRegister);
-
- if (!llvm::hlsl::rootsig::verifyRegisterSpace(Descriptor.RegisterSpace))
- return reportValueError(Ctx, "RegisterSpace", Descriptor.RegisterSpace);
-
- if (RSD.Version > 1) {
- if (!llvm::hlsl::rootsig::verifyRootDescriptorFlag(RSD.Version,
- Descriptor.Flags))
- return reportValueError(Ctx, "RootDescriptorFlag", Descriptor.Flags);
- }
- break;
- }
- case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
- const mcdxbc::DescriptorTable &Table =
- RSD.ParametersContainer.getDescriptorTable(Info.Location);
- for (const dxbc::RTS0::v2::DescriptorRange &Range : Table) {
- if (!llvm::hlsl::rootsig::verifyRangeType(Range.RangeType))
- return reportValueError(Ctx, "RangeType", Range.RangeType);
-
- if (!llvm::hlsl::rootsig::verifyRegisterSpace(Range.RegisterSpace))
- return reportValueError(Ctx, "RegisterSpace", Range.RegisterSpace);
-
- if (!llvm::hlsl::rootsig::verifyNumDescriptors(Range.NumDescriptors))
- return reportValueError(Ctx, "NumDescriptors", Range.NumDescriptors);
-
- if (!llvm::hlsl::rootsig::verifyDescriptorRangeFlag(
- RSD.Version, Range.RangeType, Range.Flags))
- return reportValueError(Ctx, "DescriptorFlag", Range.Flags);
- }
- break;
- }
- }
- }
-
- for (const dxbc::RTS0::v1::StaticSampler &Sampler : RSD.StaticSamplers) {
- if (!llvm::hlsl::rootsig::verifySamplerFilter(Sampler.Filter))
- return reportValueError(Ctx, "Filter", Sampler.Filter);
-
- if (!llvm::hlsl::rootsig::verifyAddress(Sampler.AddressU))
- return reportValueError(Ctx, "AddressU", Sampler.AddressU);
-
- if (!llvm::hlsl::rootsig::verifyAddress(Sampler.AddressV))
- return reportValueError(Ctx, "AddressV", Sampler.AddressV);
-
- if (!llvm::hlsl::rootsig::verifyAddress(Sampler.AddressW))
- return reportValueError(Ctx, "AddressW", Sampler.AddressW);
-
- if (!llvm::hlsl::rootsig::verifyMipLODBias(Sampler.MipLODBias))
- return reportValueError(Ctx, "MipLODBias", Sampler.MipLODBias);
-
- if (!llvm::hlsl::rootsig::verifyMaxAnisotropy(Sampler.MaxAnisotropy))
- return reportValueError(Ctx, "MaxAnisotropy", Sampler.MaxAnisotropy);
-
- if (!llvm::hlsl::rootsig::verifyComparisonFunc(Sampler.ComparisonFunc))
- return reportValueError(Ctx, "ComparisonFunc", Sampler.ComparisonFunc);
-
- if (!llvm::hlsl::rootsig::verifyBorderColor(Sampler.BorderColor))
- return reportValueError(Ctx, "BorderColor", Sampler.BorderColor);
-
- if (!llvm::hlsl::rootsig::verifyLOD(Sampler.MinLOD))
- return reportValueError(Ctx, "MinLOD", Sampler.MinLOD);
-
- if (!llvm::hlsl::rootsig::verifyLOD(Sampler.MaxLOD))
- return reportValueError(Ctx, "MaxLOD", Sampler.MaxLOD);
-
- if (!llvm::hlsl::rootsig::verifyRegisterValue(Sampler.ShaderRegister))
- return reportValueError(Ctx, "ShaderRegister", Sampler.ShaderRegister);
-
- if (!llvm::hlsl::rootsig::verifyRegisterSpace(Sampler.RegisterSpace))
- return reportValueError(Ctx, "RegisterSpace", Sampler.RegisterSpace);
-
- if (!dxbc::isValidShaderVisibility(Sampler.ShaderVisibility))
- return reportValueError(Ctx, "ShaderVisibility",
- Sampler.ShaderVisibility);
- }
-
- return false;
-}
-
static SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc>
analyzeModule(Module &M) {
@@ -568,14 +103,26 @@ analyzeModule(Module &M) {
reportError(Ctx, "Root Element is not a metadata node.");
continue;
}
- mcdxbc::RootSignatureDesc RSD;
- if (std::optional<uint32_t> Version = extractMdIntValue(RSDefNode, 2))
- RSD.Version = *Version;
+ uint32_t Version = 1;
+ if (std::optional<uint32_t> V =
+ llvm::hlsl::rootsig::extractMdIntValue(RSDefNode, 2))
+ Version = *V;
else {
reportError(Ctx, "Invalid RSDefNode value, expected constant int");
continue;
}
+ llvm::hlsl::rootsig::MetadataParser MDParser(RootElementListNode);
+ llvm::Expected<mcdxbc::RootSignatureDesc> RSDOrErr =
+ MDParser.ParseRootSignature(Version);
+
+ if (auto Err = RSDOrErr.takeError()) {
+ reportError(Ctx, toString(std::move(Err)));
+ continue;
+ }
+
+ auto &RSD = *RSDOrErr;
+
// Clang emits the root signature data in dxcontainer following a specific
// sequence. First the header, then the root parameters. So the header
// offset will always equal to the header size.
@@ -584,10 +131,6 @@ analyzeModule(Module &M) {
// static sampler offset is calculated when writting dxcontainer.
RSD.StaticSamplersOffset = 0u;
- if (parse(Ctx, RSD, RootElementListNode) || validate(Ctx, RSD)) {
- return RSDMap;
- }
-
RSDMap.insert(std::make_pair(F, RSD));
}
diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.h b/llvm/lib/Target/DirectX/DXILRootSignature.h
index 76328bb15fa58..254b7ff504633 100644
--- a/llvm/lib/Target/DirectX/DXILRootSignature.h
+++ b/llvm/lib/Target/DirectX/DXILRootSignature.h
@@ -26,8 +26,6 @@
namespace llvm {
namespace dxil {
-
-
class RootSignatureBindingInfo {
private:
SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc> FuncToRsMap;
diff --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-Invalid-RangeType.ll b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-Invalid-RangeType.ll
index 0f7116307c315..644e4e4348980 100644
--- a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-Invalid-RangeType.ll
+++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-Invalid-RangeType.ll
@@ -2,7 +2,7 @@
target triple = "dxil-unknown-shadermodel6.0-compute"
-; CHECK: error: Invalid Descriptor Range type: Invalid
+; CHECK: error: Invalid Descriptor Range type:Invalid
; CHECK-NOT: Root Signature Definitions
define void @main() #0 {
diff --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-Flags-Error.ll b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-Flags-Error.ll
index 65511160f230d..41e97701dcc20 100644
--- a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-Flags-Error.ll
+++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-Flags-Error.ll
@@ -2,7 +2,7 @@
target triple = "dxil-unknown-shadermodel6.0-compute"
-; CHECK: error: Invalid Root Signature Element: NOTRootFlags
+; CHECK: error: Invalid Root Signature Element:NOTRootFlags
; CHECK-NOT: Root Signature Definitions
diff --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootDescriptor-Invalid-RegisterKind.ll b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootDescriptor-Invalid-RegisterKind.ll
index 579528d8b5e13..9e63b06674ebc 100644
--- a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootDescriptor-Invalid-RegisterKind.ll
+++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootDescriptor-Invalid-RegisterKind.ll
@@ -3,7 +3,7 @@
target triple = "dxil-unknown-shadermodel6.0-compute"
-; CHECK: error: Invalid Root Signature Element: Invalid
+; CHECK: error: Invalid Root Signature Element:Invalid
; CHECK-NOT: Root Signature Definitions
define void @main() #0 {
entry:
diff --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-StaticSamplers-Invalid-MaxLod.ll b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-StaticSamplers-Invalid-MaxLod.ll
index 7e7d56eff309c..855e0c0cb6e51 100644
--- a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-StaticSamplers-Invalid-MaxLod.ll
+++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-StaticSamplers-Invalid-MaxLod.ll
@@ -3,7 +3,7 @@
target triple = "dxil-unknown-shadermodel6.0-compute"
-; CHECK: error: Invalid value for MaxLOD: 0
+; CHECK: error: Invalid value for MaxLOD: nan
; CHECK-NOT: Root Signature Definitions
define void @main() #0 {
diff --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-StaticSamplers-Invalid-MinLod.ll b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-StaticSamplers-Invalid-MinLod.ll
index d958f10d3c1af..812749b9ed824 100644
--- a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-StaticSamplers-Invalid-MinLod.ll
+++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-StaticSamplers-Invalid-MinLod.ll
@@ -3,7 +3,7 @@
target triple = "dxil-unknown-shadermodel6.0-compute"
-; CHECK: error: Invalid value for MinLOD: 0
+; CHECK: error: Invalid value for MinLOD: nan
; CHECK-NOT: Root Signature Definitions
define void @main() #0 {
diff --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-StaticSamplers-Invalid-MinLopBias.ll b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-StaticSamplers-Invalid-MinLopBias.ll
index 34b27eb40f5fb..6898aec6f2e49 100644
--- a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-StaticSamplers-Invalid-MinLopBias.ll
+++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-StaticSamplers-Invalid-MinLopBias.ll
@@ -3,7 +3,7 @@
target triple = "dxil-unknown-shadermodel6.0-compute"
-; CHECK: error: Invalid value for MipLODBias: 666
+; CHECK: error: Invalid value for MipLODBias: 6.660000e+02
; CHECK-NOT: Root Signature Definitions
define void @main() #0 {
>From ffc696e0e30c46c25296b9dad51fe92c22ac6b3a Mon Sep 17 00:00:00 2001
From: joaosaffran <joao.saffran at microsoft.com>
Date: Wed, 16 Jul 2025 23:08:31 +0000
Subject: [PATCH 03/10] clean up
---
llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h | 4 ----
llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp | 4 ----
2 files changed, 8 deletions(-)
diff --git a/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h b/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h
index 6f337660ee6c8..729ea22d3c8ab 100644
--- a/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h
+++ b/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h
@@ -16,11 +16,7 @@
#include "llvm/Frontend/HLSL/HLSLRootSignature.h"
#include "llvm/IR/Constants.h"
-#include "llvm/IR/Function.h"
#include "llvm/MC/DXContainerRootSignature.h"
-#include "llvm/Support/Error.h"
-#include <cstdint>
-#include <unordered_map>
namespace llvm {
class LLVMContext;
diff --git a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
index 89e130db796db..23c1815d438ad 100644
--- a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
+++ b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
@@ -13,13 +13,9 @@
#include "llvm/Frontend/HLSL/RootSignatureMetadata.h"
#include "llvm/Frontend/HLSL/RootSignatureValidations.h"
-#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Metadata.h"
-#include "llvm/Support/Error.h"
#include "llvm/Support/ScopedPrinter.h"
-#include <cstdint>
-#include <utility>
namespace llvm {
namespace hlsl {
>From d9c1c9638dd5d20d9d6c0dc583accb1046e1200b Mon Sep 17 00:00:00 2001
From: joaosaffran <joao.saffran at microsoft.com>
Date: Thu, 17 Jul 2025 01:04:19 +0000
Subject: [PATCH 04/10] keep only the move
---
.../Frontend/HLSL/RootSignatureMetadata.h | 131 ++-----
.../Frontend/HLSL/RootSignatureMetadata.cpp | 345 ++++++++----------
llvm/lib/Target/DirectX/DXILRootSignature.cpp | 48 +--
...ature-DescriptorTable-Invalid-RangeType.ll | 2 +-
.../RootSignature-Flags-Error.ll | 2 +-
...ure-RootDescriptor-Invalid-RegisterKind.ll | 2 +-
...Signature-StaticSamplers-Invalid-MaxLod.ll | 2 +-
...Signature-StaticSamplers-Invalid-MinLod.ll | 2 +-
...ature-StaticSamplers-Invalid-MinLopBias.ll | 2 +-
9 files changed, 204 insertions(+), 332 deletions(-)
diff --git a/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h b/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h
index 729ea22d3c8ab..e0639530cd536 100644
--- a/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h
+++ b/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h
@@ -25,97 +25,6 @@ class Metadata;
namespace hlsl {
namespace rootsig {
-
-inline std::optional<uint32_t> extractMdIntValue(MDNode *Node,
- unsigned int OpId) {
- if (auto *CI =
- mdconst::dyn_extract<ConstantInt>(Node->getOperand(OpId).get()))
- return CI->getZExtValue();
- return std::nullopt;
-}
-
-inline std::optional<float> extractMdFloatValue(MDNode *Node,
- unsigned int OpId) {
- if (auto *CI = mdconst::dyn_extract<ConstantFP>(Node->getOperand(OpId).get()))
- return CI->getValueAPF().convertToFloat();
- return std::nullopt;
-}
-
-inline std::optional<StringRef> extractMdStringValue(MDNode *Node,
- unsigned int OpId) {
- MDString *NodeText = dyn_cast<MDString>(Node->getOperand(OpId));
- if (NodeText == nullptr)
- return std::nullopt;
- return NodeText->getString();
-}
-
-template <typename T>
-class RootSignatureValidationError
- : public ErrorInfo<RootSignatureValidationError<T>> {
-public:
- static char ID;
- std::string ParamName;
- T Value;
-
- RootSignatureValidationError(StringRef ParamName, T Value)
- : ParamName(ParamName.str()), Value(Value) {}
-
- void log(raw_ostream &OS) const override {
- OS << "Invalid value for " << ParamName << ": " << Value;
- }
-
- std::error_code convertToErrorCode() const override {
- return llvm::inconvertibleErrorCode();
- }
-};
-
-class GenericRSMetadataError : public ErrorInfo<GenericRSMetadataError> {
-public:
- static char ID;
- std::string Message;
-
- GenericRSMetadataError(Twine Message) : Message(Message.str()) {}
-
- void log(raw_ostream &OS) const override { OS << Message; }
-
- std::error_code convertToErrorCode() const override {
- return llvm::inconvertibleErrorCode();
- }
-};
-
-class InvalidRSMetadataFormat : public ErrorInfo<InvalidRSMetadataFormat> {
-public:
- static char ID;
- std::string ElementName;
-
- InvalidRSMetadataFormat(StringRef ElementName)
- : ElementName(ElementName.str()) {}
-
- void log(raw_ostream &OS) const override {
- OS << "Invalid format for " << ElementName;
- }
-
- std::error_code convertToErrorCode() const override {
- return llvm::inconvertibleErrorCode();
- }
-};
-
-class InvalidRSMetadataValue : public ErrorInfo<InvalidRSMetadataValue> {
-public:
- static char ID;
- std::string ParamName;
-
- InvalidRSMetadataValue(StringRef ParamName) : ParamName(ParamName.str()) {}
-
- void log(raw_ostream &OS) const override {
- OS << "Invalid value for " << ParamName;
- }
-
- std::error_code convertToErrorCode() const override {
- return llvm::inconvertibleErrorCode();
- }
-};
-
class MetadataBuilder {
public:
MetadataBuilder(llvm::LLVMContext &Ctx, ArrayRef<RootElement> Elements)
@@ -157,27 +66,29 @@ class MetadataParser {
MetadataParser(MDNode *Root) : Root(Root) {}
/// Iterates through root signature and converts them into MapT
- LLVM_ABI llvm::Expected<llvm::mcdxbc::RootSignatureDesc>
- ParseRootSignature(uint32_t Version);
+ LLVM_ABI bool ParseRootSignature(LLVMContext *Ctx,
+ mcdxbc::RootSignatureDesc &RSD);
private:
- llvm::Error parseRootFlags(mcdxbc::RootSignatureDesc &RSD,
- MDNode *RootFlagNode);
- llvm::Error parseRootConstants(mcdxbc::RootSignatureDesc &RSD,
- MDNode *RootConstantNode);
- llvm::Error parseRootDescriptors(mcdxbc::RootSignatureDesc &RSD,
- MDNode *RootDescriptorNode,
- RootSignatureElementKind ElementKind);
- llvm::Error parseDescriptorRange(mcdxbc::DescriptorTable &Table,
- MDNode *RangeDescriptorNode);
- llvm::Error parseDescriptorTable(mcdxbc::RootSignatureDesc &RSD,
- MDNode *DescriptorTableNode);
- llvm::Error parseRootSignatureElement(mcdxbc::RootSignatureDesc &RSD,
- MDNode *Element);
- llvm::Error parseStaticSampler(mcdxbc::RootSignatureDesc &RSD,
- MDNode *StaticSamplerNode);
-
- llvm::Error validateRootSignature(const llvm::mcdxbc::RootSignatureDesc &RSD);
+ bool parseRootFlags(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
+ MDNode *RootFlagNode);
+ bool parseRootConstants(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
+ MDNode *RootConstantNode);
+ bool parseRootDescriptors(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
+ MDNode *RootDescriptorNode,
+ RootSignatureElementKind ElementKind);
+ bool parseDescriptorRange(LLVMContext *Ctx, mcdxbc::DescriptorTable &Table,
+ MDNode *RangeDescriptorNode);
+ bool parseDescriptorTable(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
+ MDNode *DescriptorTableNode);
+ bool parseRootSignatureElement(LLVMContext *Ctx,
+ mcdxbc::RootSignatureDesc &RSD,
+ MDNode *Element);
+ bool parseStaticSampler(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
+ MDNode *StaticSamplerNode);
+
+ bool validateRootSignature(LLVMContext *Ctx,
+ const llvm::mcdxbc::RootSignatureDesc &RSD);
MDNode *Root;
};
diff --git a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
index 23c1815d438ad..53f59349ae029 100644
--- a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
+++ b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
@@ -13,6 +13,7 @@
#include "llvm/Frontend/HLSL/RootSignatureMetadata.h"
#include "llvm/Frontend/HLSL/RootSignatureValidations.h"
+#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Metadata.h"
#include "llvm/Support/ScopedPrinter.h"
@@ -21,11 +22,41 @@ namespace llvm {
namespace hlsl {
namespace rootsig {
-char GenericRSMetadataError::ID;
-char InvalidRSMetadataFormat::ID;
-char InvalidRSMetadataValue::ID;
+static std::optional<uint32_t> extractMdIntValue(MDNode *Node,
+ unsigned int OpId) {
+ if (auto *CI =
+ mdconst::dyn_extract<ConstantInt>(Node->getOperand(OpId).get()))
+ return CI->getZExtValue();
+ return std::nullopt;
+}
-template <typename T> char RootSignatureValidationError<T>::ID;
+static std::optional<float> extractMdFloatValue(MDNode *Node,
+ unsigned int OpId) {
+ if (auto *CI = mdconst::dyn_extract<ConstantFP>(Node->getOperand(OpId).get()))
+ return CI->getValueAPF().convertToFloat();
+ return std::nullopt;
+}
+
+static std::optional<StringRef> extractMdStringValue(MDNode *Node,
+ unsigned int OpId) {
+ MDString *NodeText = dyn_cast<MDString>(Node->getOperand(OpId));
+ if (NodeText == nullptr)
+ return std::nullopt;
+ return NodeText->getString();
+}
+
+static bool reportError(LLVMContext *Ctx, Twine Message,
+ DiagnosticSeverity Severity = DS_Error) {
+ Ctx->diagnose(DiagnosticInfoGeneric(Message, Severity));
+ return true;
+}
+
+static bool reportValueError(LLVMContext *Ctx, Twine ParamName,
+ uint32_t Value) {
+ Ctx->diagnose(DiagnosticInfoGeneric(
+ "Invalid value for " + ParamName + ": " + Twine(Value), DS_Error));
+ return true;
+}
static const EnumEntry<dxil::ResourceClass> ResourceClassNames[] = {
{"CBV", dxil::ResourceClass::CBuffer},
@@ -196,23 +227,27 @@ MDNode *MetadataBuilder::BuildStaticSampler(const StaticSampler &Sampler) {
return MDNode::get(Ctx, Operands);
}
-llvm::Error MetadataParser::parseRootFlags(mcdxbc::RootSignatureDesc &RSD,
- MDNode *RootFlagNode) {
+bool MetadataParser::parseRootFlags(LLVMContext *Ctx,
+ mcdxbc::RootSignatureDesc &RSD,
+ MDNode *RootFlagNode) {
+
if (RootFlagNode->getNumOperands() != 2)
- return make_error<InvalidRSMetadataFormat>("RootFlag Element");
+ return reportError(Ctx, "Invalid format for RootFlag Element");
if (std::optional<uint32_t> Val = extractMdIntValue(RootFlagNode, 1))
RSD.Flags = *Val;
else
- return make_error<InvalidRSMetadataValue>("RootFlag");
+ return reportError(Ctx, "Invalid value for RootFlag");
- return llvm::Error::success();
+ return false;
}
-llvm::Error MetadataParser::parseRootConstants(mcdxbc::RootSignatureDesc &RSD,
- MDNode *RootConstantNode) {
+bool MetadataParser::parseRootConstants(LLVMContext *Ctx,
+ mcdxbc::RootSignatureDesc &RSD,
+ MDNode *RootConstantNode) {
+
if (RootConstantNode->getNumOperands() != 5)
- return make_error<InvalidRSMetadataFormat>("RootConstants Element");
+ return reportError(Ctx, "Invalid format for RootConstants Element");
dxbc::RTS0::v1::RootParameterHeader Header;
// The parameter offset doesn't matter here - we recalculate it during
@@ -223,40 +258,39 @@ llvm::Error MetadataParser::parseRootConstants(mcdxbc::RootSignatureDesc &RSD,
if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 1))
Header.ShaderVisibility = *Val;
else
- return make_error<InvalidRSMetadataValue>("ShaderVisibility");
+ return reportError(Ctx, "Invalid value for ShaderVisibility");
dxbc::RTS0::v1::RootConstants Constants;
if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 2))
Constants.ShaderRegister = *Val;
else
- return make_error<InvalidRSMetadataValue>("ShaderRegister");
+ return reportError(Ctx, "Invalid value for ShaderRegister");
if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 3))
Constants.RegisterSpace = *Val;
else
- return make_error<InvalidRSMetadataValue>("RegisterSpace");
+ return reportError(Ctx, "Invalid value for RegisterSpace");
if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 4))
Constants.Num32BitValues = *Val;
else
- return make_error<InvalidRSMetadataValue>("Num32BitValues");
+ return reportError(Ctx, "Invalid value for Num32BitValues");
RSD.ParametersContainer.addParameter(Header, Constants);
- return llvm::Error::success();
+ return false;
}
-llvm::Error
-MetadataParser::parseRootDescriptors(mcdxbc::RootSignatureDesc &RSD,
- MDNode *RootDescriptorNode,
- RootSignatureElementKind ElementKind) {
+bool MetadataParser::parseRootDescriptors(
+ LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
+ MDNode *RootDescriptorNode, RootSignatureElementKind ElementKind) {
assert(ElementKind == RootSignatureElementKind::SRV ||
ElementKind == RootSignatureElementKind::UAV ||
ElementKind == RootSignatureElementKind::CBV &&
- "parseRootDescriptors should only be called with RootDescriptor"
+ "parseRootDescriptors should only be called with RootDescriptor "
"element kind.");
if (RootDescriptorNode->getNumOperands() != 5)
- return make_error<InvalidRSMetadataFormat>("Root Descriptor Element");
+ return reportError(Ctx, "Invalid format for Root Descriptor Element");
dxbc::RTS0::v1::RootParameterHeader Header;
switch (ElementKind) {
@@ -277,38 +311,40 @@ MetadataParser::parseRootDescriptors(mcdxbc::RootSignatureDesc &RSD,
if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 1))
Header.ShaderVisibility = *Val;
else
- return make_error<InvalidRSMetadataValue>("ShaderVisibility");
+ return reportError(Ctx, "Invalid value for ShaderVisibility");
dxbc::RTS0::v2::RootDescriptor Descriptor;
if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 2))
Descriptor.ShaderRegister = *Val;
else
- return make_error<InvalidRSMetadataValue>("ShaderRegister");
+ return reportError(Ctx, "Invalid value for ShaderRegister");
if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 3))
Descriptor.RegisterSpace = *Val;
else
- return make_error<InvalidRSMetadataValue>("RegisterSpace");
+ return reportError(Ctx, "Invalid value for RegisterSpace");
if (RSD.Version == 1) {
RSD.ParametersContainer.addParameter(Header, Descriptor);
- return llvm::Error::success();
+ return false;
}
assert(RSD.Version > 1);
if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 4))
Descriptor.Flags = *Val;
else
- return make_error<InvalidRSMetadataValue>("Root Descriptor Flags");
+ return reportError(Ctx, "Invalid value for Root Descriptor Flags");
RSD.ParametersContainer.addParameter(Header, Descriptor);
- return llvm::Error::success();
+ return false;
}
-llvm::Error MetadataParser::parseDescriptorRange(mcdxbc::DescriptorTable &Table,
- MDNode *RangeDescriptorNode) {
+bool MetadataParser::parseDescriptorRange(LLVMContext *Ctx,
+ mcdxbc::DescriptorTable &Table,
+ MDNode *RangeDescriptorNode) {
+
if (RangeDescriptorNode->getNumOperands() != 6)
- return make_error<InvalidRSMetadataFormat>("Descriptor Range");
+ return reportError(Ctx, "Invalid format for Descriptor Range");
dxbc::RTS0::v2::DescriptorRange Range;
@@ -316,7 +352,7 @@ llvm::Error MetadataParser::parseDescriptorRange(mcdxbc::DescriptorTable &Table,
extractMdStringValue(RangeDescriptorNode, 0);
if (!ElementText.has_value())
- return make_error<InvalidRSMetadataFormat>("Descriptor Range");
+ return reportError(Ctx, "Descriptor Range, first element is not a string.");
Range.RangeType =
StringSwitch<uint32_t>(*ElementText)
@@ -328,50 +364,50 @@ llvm::Error MetadataParser::parseDescriptorRange(mcdxbc::DescriptorTable &Table,
.Default(~0U);
if (Range.RangeType == ~0U)
- return make_error<GenericRSMetadataError>("Invalid Descriptor Range type:" +
- *ElementText);
+ return reportError(Ctx, "Invalid Descriptor Range type: " + *ElementText);
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 1))
Range.NumDescriptors = *Val;
else
- return make_error<GenericRSMetadataError>("Number of Descriptor in Range");
+ return reportError(Ctx, "Invalid value for Number of Descriptor in Range");
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 2))
Range.BaseShaderRegister = *Val;
else
- return make_error<InvalidRSMetadataValue>("BaseShaderRegister");
+ return reportError(Ctx, "Invalid value for BaseShaderRegister");
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 3))
Range.RegisterSpace = *Val;
else
- return make_error<InvalidRSMetadataValue>("RegisterSpace");
+ return reportError(Ctx, "Invalid value for RegisterSpace");
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 4))
Range.OffsetInDescriptorsFromTableStart = *Val;
else
- return make_error<InvalidRSMetadataValue>(
- "OffsetInDescriptorsFromTableStart");
+ return reportError(Ctx,
+ "Invalid value for OffsetInDescriptorsFromTableStart");
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 5))
Range.Flags = *Val;
else
- return make_error<InvalidRSMetadataValue>("Descriptor Range Flags");
+ return reportError(Ctx, "Invalid value for Descriptor Range Flags");
Table.Ranges.push_back(Range);
- return llvm::Error::success();
+ return false;
}
-llvm::Error MetadataParser::parseDescriptorTable(mcdxbc::RootSignatureDesc &RSD,
- MDNode *DescriptorTableNode) {
+bool MetadataParser::parseDescriptorTable(LLVMContext *Ctx,
+ mcdxbc::RootSignatureDesc &RSD,
+ MDNode *DescriptorTableNode) {
const unsigned int NumOperands = DescriptorTableNode->getNumOperands();
if (NumOperands < 2)
- return make_error<InvalidRSMetadataFormat>("Descriptor Table");
+ return reportError(Ctx, "Invalid format for Descriptor Table");
dxbc::RTS0::v1::RootParameterHeader Header;
if (std::optional<uint32_t> Val = extractMdIntValue(DescriptorTableNode, 1))
Header.ShaderVisibility = *Val;
else
- return make_error<InvalidRSMetadataValue>("ShaderVisibility");
+ return reportError(Ctx, "Invalid value for ShaderVisibility");
mcdxbc::DescriptorTable Table;
Header.ParameterType =
@@ -380,98 +416,98 @@ llvm::Error MetadataParser::parseDescriptorTable(mcdxbc::RootSignatureDesc &RSD,
for (unsigned int I = 2; I < NumOperands; I++) {
MDNode *Element = dyn_cast<MDNode>(DescriptorTableNode->getOperand(I));
if (Element == nullptr)
- return make_error<GenericRSMetadataError>(
- "Missing Root Element Metadata Node.");
+ return reportError(Ctx, "Missing Root Element Metadata Node.");
- if (auto Err = parseDescriptorRange(Table, Element))
- return Err;
+ if (parseDescriptorRange(Ctx, Table, Element))
+ return true;
}
RSD.ParametersContainer.addParameter(Header, Table);
- return llvm::Error::success();
+ return false;
}
-llvm::Error MetadataParser::parseStaticSampler(mcdxbc::RootSignatureDesc &RSD,
- MDNode *StaticSamplerNode) {
+bool MetadataParser::parseStaticSampler(LLVMContext *Ctx,
+ mcdxbc::RootSignatureDesc &RSD,
+ MDNode *StaticSamplerNode) {
if (StaticSamplerNode->getNumOperands() != 14)
- return make_error<InvalidRSMetadataFormat>("Static Sampler");
+ return reportError(Ctx, "Invalid format for Static Sampler");
dxbc::RTS0::v1::StaticSampler Sampler;
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 1))
Sampler.Filter = *Val;
else
- return make_error<InvalidRSMetadataValue>("Filter");
+ return reportError(Ctx, "Invalid value for Filter");
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 2))
Sampler.AddressU = *Val;
else
- return make_error<InvalidRSMetadataValue>("AddressU");
+ return reportError(Ctx, "Invalid value for AddressU");
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 3))
Sampler.AddressV = *Val;
else
- return make_error<InvalidRSMetadataValue>("AddressV");
+ return reportError(Ctx, "Invalid value for AddressV");
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 4))
Sampler.AddressW = *Val;
else
- return make_error<InvalidRSMetadataValue>("AddressW");
+ return reportError(Ctx, "Invalid value for AddressW");
if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 5))
Sampler.MipLODBias = *Val;
else
- return make_error<InvalidRSMetadataValue>("MipLODBias");
+ return reportError(Ctx, "Invalid value for MipLODBias");
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 6))
Sampler.MaxAnisotropy = *Val;
else
- return make_error<InvalidRSMetadataValue>("MaxAnisotropy");
+ return reportError(Ctx, "Invalid value for MaxAnisotropy");
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 7))
Sampler.ComparisonFunc = *Val;
else
- return make_error<InvalidRSMetadataValue>("ComparisonFunc");
+ return reportError(Ctx, "Invalid value for ComparisonFunc");
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 8))
Sampler.BorderColor = *Val;
else
- return make_error<InvalidRSMetadataValue>("ComparisonFunc");
+ return reportError(Ctx, "Invalid value for BorderColor");
if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 9))
Sampler.MinLOD = *Val;
else
- return make_error<InvalidRSMetadataValue>("MinLOD");
+ return reportError(Ctx, "Invalid value for MinLOD");
if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 10))
Sampler.MaxLOD = *Val;
else
- return make_error<InvalidRSMetadataValue>("MaxLOD");
+ return reportError(Ctx, "Invalid value for MaxLOD");
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 11))
Sampler.ShaderRegister = *Val;
else
- return make_error<InvalidRSMetadataValue>("ShaderRegister");
+ return reportError(Ctx, "Invalid value for ShaderRegister");
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 12))
Sampler.RegisterSpace = *Val;
else
- return make_error<InvalidRSMetadataValue>("RegisterSpace");
+ return reportError(Ctx, "Invalid value for RegisterSpace");
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 13))
Sampler.ShaderVisibility = *Val;
else
- return make_error<InvalidRSMetadataValue>("ShaderVisibility");
+ return reportError(Ctx, "Invalid value for ShaderVisibility");
RSD.StaticSamplers.push_back(Sampler);
- return llvm::Error::success();
+ return false;
}
-llvm::Error
-MetadataParser::parseRootSignatureElement(mcdxbc::RootSignatureDesc &RSD,
- MDNode *Element) {
+bool MetadataParser::parseRootSignatureElement(LLVMContext *Ctx,
+ mcdxbc::RootSignatureDesc &RSD,
+ MDNode *Element) {
std::optional<StringRef> ElementText = extractMdStringValue(Element, 0);
if (!ElementText.has_value())
- return make_error<InvalidRSMetadataFormat>("Root Element");
+ return reportError(Ctx, "Invalid format for Root Element");
RootSignatureElementKind ElementKind =
StringSwitch<RootSignatureElementKind>(*ElementText)
@@ -487,48 +523,38 @@ MetadataParser::parseRootSignatureElement(mcdxbc::RootSignatureDesc &RSD,
switch (ElementKind) {
case RootSignatureElementKind::RootFlags:
- return parseRootFlags(RSD, Element);
+ return parseRootFlags(Ctx, RSD, Element);
case RootSignatureElementKind::RootConstants:
- return parseRootConstants(RSD, Element);
+ return parseRootConstants(Ctx, RSD, Element);
case RootSignatureElementKind::CBV:
case RootSignatureElementKind::SRV:
case RootSignatureElementKind::UAV:
- return parseRootDescriptors(RSD, Element, ElementKind);
+ return parseRootDescriptors(Ctx, RSD, Element, ElementKind);
case RootSignatureElementKind::DescriptorTable:
- return parseDescriptorTable(RSD, Element);
+ return parseDescriptorTable(Ctx, RSD, Element);
case RootSignatureElementKind::StaticSamplers:
- return parseStaticSampler(RSD, Element);
+ return parseStaticSampler(Ctx, RSD, Element);
case RootSignatureElementKind::Error:
- return make_error<GenericRSMetadataError>(
- "Invalid Root Signature Element:" + *ElementText);
+ return reportError(Ctx, "Invalid Root Signature Element: " + *ElementText);
}
llvm_unreachable("Unhandled RootSignatureElementKind enum.");
}
-llvm::Error MetadataParser::validateRootSignature(
- const llvm::mcdxbc::RootSignatureDesc &RSD) {
- Error DeferredErrs = Error::success();
+bool MetadataParser::validateRootSignature(
+ LLVMContext *Ctx, const llvm::mcdxbc::RootSignatureDesc &RSD) {
if (!llvm::hlsl::rootsig::verifyVersion(RSD.Version)) {
- DeferredErrs =
- joinErrors(std::move(DeferredErrs),
- llvm::make_error<RootSignatureValidationError<uint32_t>>(
- "Version", RSD.Version));
+ return reportValueError(Ctx, "Version", RSD.Version);
}
if (!llvm::hlsl::rootsig::verifyRootFlag(RSD.Flags)) {
- DeferredErrs =
- joinErrors(std::move(DeferredErrs),
- llvm::make_error<RootSignatureValidationError<uint32_t>>(
- "RootFlags", RSD.Flags));
+ return reportValueError(Ctx, "RootFlags", RSD.Flags);
}
for (const mcdxbc::RootParameterInfo &Info : RSD.ParametersContainer) {
if (!dxbc::isValidShaderVisibility(Info.Header.ShaderVisibility))
- DeferredErrs =
- joinErrors(std::move(DeferredErrs),
- llvm::make_error<RootSignatureValidationError<uint32_t>>(
- "ShaderVisibility", Info.Header.ShaderVisibility));
+ return reportValueError(Ctx, "ShaderVisibility",
+ Info.Header.ShaderVisibility);
assert(dxbc::isValidParameterType(Info.Header.ParameterType) &&
"Invalid value for ParameterType");
@@ -541,24 +567,16 @@ llvm::Error MetadataParser::validateRootSignature(
const dxbc::RTS0::v2::RootDescriptor &Descriptor =
RSD.ParametersContainer.getRootDescriptor(Info.Location);
if (!llvm::hlsl::rootsig::verifyRegisterValue(Descriptor.ShaderRegister))
- DeferredErrs =
- joinErrors(std::move(DeferredErrs),
- llvm::make_error<RootSignatureValidationError<uint32_t>>(
- "ShaderRegister", Descriptor.ShaderRegister));
+ return reportValueError(Ctx, "ShaderRegister",
+ Descriptor.ShaderRegister);
if (!llvm::hlsl::rootsig::verifyRegisterSpace(Descriptor.RegisterSpace))
- DeferredErrs =
- joinErrors(std::move(DeferredErrs),
- llvm::make_error<RootSignatureValidationError<uint32_t>>(
- "RegisterSpace", Descriptor.RegisterSpace));
+ return reportValueError(Ctx, "RegisterSpace", Descriptor.RegisterSpace);
if (RSD.Version > 1) {
if (!llvm::hlsl::rootsig::verifyRootDescriptorFlag(RSD.Version,
Descriptor.Flags))
- DeferredErrs = joinErrors(
- std::move(DeferredErrs),
- llvm::make_error<RootSignatureValidationError<uint32_t>>(
- "RootDescriptorFlag", Descriptor.Flags));
+ return reportValueError(Ctx, "RootDescriptorFlag", Descriptor.Flags);
}
break;
}
@@ -567,29 +585,17 @@ llvm::Error MetadataParser::validateRootSignature(
RSD.ParametersContainer.getDescriptorTable(Info.Location);
for (const dxbc::RTS0::v2::DescriptorRange &Range : Table) {
if (!llvm::hlsl::rootsig::verifyRangeType(Range.RangeType))
- DeferredErrs = joinErrors(
- std::move(DeferredErrs),
- llvm::make_error<RootSignatureValidationError<uint32_t>>(
- "RangeType", Range.RangeType));
+ return reportValueError(Ctx, "RangeType", Range.RangeType);
if (!llvm::hlsl::rootsig::verifyRegisterSpace(Range.RegisterSpace))
- DeferredErrs = joinErrors(
- std::move(DeferredErrs),
- llvm::make_error<RootSignatureValidationError<uint32_t>>(
- "RegisterSpace", Range.RegisterSpace));
+ return reportValueError(Ctx, "RegisterSpace", Range.RegisterSpace);
if (!llvm::hlsl::rootsig::verifyNumDescriptors(Range.NumDescriptors))
- DeferredErrs = joinErrors(
- std::move(DeferredErrs),
- llvm::make_error<RootSignatureValidationError<uint32_t>>(
- "NumDescriptors", Range.NumDescriptors));
+ return reportValueError(Ctx, "NumDescriptors", Range.NumDescriptors);
if (!llvm::hlsl::rootsig::verifyDescriptorRangeFlag(
RSD.Version, Range.RangeType, Range.Flags))
- DeferredErrs = joinErrors(
- std::move(DeferredErrs),
- llvm::make_error<RootSignatureValidationError<uint32_t>>(
- "DescriptorFlag", Range.Flags));
+ return reportValueError(Ctx, "DescriptorFlag", Range.Flags);
}
break;
}
@@ -598,111 +604,64 @@ llvm::Error MetadataParser::validateRootSignature(
for (const dxbc::RTS0::v1::StaticSampler &Sampler : RSD.StaticSamplers) {
if (!llvm::hlsl::rootsig::verifySamplerFilter(Sampler.Filter))
- DeferredErrs =
- joinErrors(std::move(DeferredErrs),
- llvm::make_error<RootSignatureValidationError<uint32_t>>(
- "Filter", Sampler.Filter));
+ return reportValueError(Ctx, "Filter", Sampler.Filter);
if (!llvm::hlsl::rootsig::verifyAddress(Sampler.AddressU))
- DeferredErrs =
- joinErrors(std::move(DeferredErrs),
- llvm::make_error<RootSignatureValidationError<uint32_t>>(
- "AddressU", Sampler.AddressU));
+ return reportValueError(Ctx, "AddressU", Sampler.AddressU);
if (!llvm::hlsl::rootsig::verifyAddress(Sampler.AddressV))
- DeferredErrs =
- joinErrors(std::move(DeferredErrs),
- llvm::make_error<RootSignatureValidationError<uint32_t>>(
- "AddressV", Sampler.AddressV));
+ return reportValueError(Ctx, "AddressV", Sampler.AddressV);
if (!llvm::hlsl::rootsig::verifyAddress(Sampler.AddressW))
- DeferredErrs =
- joinErrors(std::move(DeferredErrs),
- llvm::make_error<RootSignatureValidationError<uint32_t>>(
- "AddressW", Sampler.AddressW));
+ return reportValueError(Ctx, "AddressW", Sampler.AddressW);
if (!llvm::hlsl::rootsig::verifyMipLODBias(Sampler.MipLODBias))
- DeferredErrs =
- joinErrors(std::move(DeferredErrs),
- llvm::make_error<RootSignatureValidationError<float>>(
- "MipLODBias", Sampler.MipLODBias));
+ return reportValueError(Ctx, "MipLODBias", Sampler.MipLODBias);
if (!llvm::hlsl::rootsig::verifyMaxAnisotropy(Sampler.MaxAnisotropy))
- DeferredErrs =
- joinErrors(std::move(DeferredErrs),
- llvm::make_error<RootSignatureValidationError<uint32_t>>(
- "MaxAnisotropy", Sampler.MaxAnisotropy));
+ return reportValueError(Ctx, "MaxAnisotropy", Sampler.MaxAnisotropy);
if (!llvm::hlsl::rootsig::verifyComparisonFunc(Sampler.ComparisonFunc))
- DeferredErrs =
- joinErrors(std::move(DeferredErrs),
- llvm::make_error<RootSignatureValidationError<uint32_t>>(
- "ComparisonFunc", Sampler.ComparisonFunc));
+ return reportValueError(Ctx, "ComparisonFunc", Sampler.ComparisonFunc);
if (!llvm::hlsl::rootsig::verifyBorderColor(Sampler.BorderColor))
- DeferredErrs =
- joinErrors(std::move(DeferredErrs),
- llvm::make_error<RootSignatureValidationError<uint32_t>>(
- "BorderColor", Sampler.BorderColor));
+ return reportValueError(Ctx, "BorderColor", Sampler.BorderColor);
if (!llvm::hlsl::rootsig::verifyLOD(Sampler.MinLOD))
- DeferredErrs =
- joinErrors(std::move(DeferredErrs),
- llvm::make_error<RootSignatureValidationError<float>>(
- "MinLOD", Sampler.MinLOD));
+ return reportValueError(Ctx, "MinLOD", Sampler.MinLOD);
if (!llvm::hlsl::rootsig::verifyLOD(Sampler.MaxLOD))
- DeferredErrs =
- joinErrors(std::move(DeferredErrs),
- llvm::make_error<RootSignatureValidationError<float>>(
- "MaxLOD", Sampler.MaxLOD));
+ return reportValueError(Ctx, "MaxLOD", Sampler.MaxLOD);
if (!llvm::hlsl::rootsig::verifyRegisterValue(Sampler.ShaderRegister))
- DeferredErrs =
- joinErrors(std::move(DeferredErrs),
- llvm::make_error<RootSignatureValidationError<uint32_t>>(
- "ShaderRegister", Sampler.ShaderRegister));
+ return reportValueError(Ctx, "ShaderRegister", Sampler.ShaderRegister);
if (!llvm::hlsl::rootsig::verifyRegisterSpace(Sampler.RegisterSpace))
- DeferredErrs =
- joinErrors(std::move(DeferredErrs),
- llvm::make_error<RootSignatureValidationError<uint32_t>>(
- "RegisterSpace", Sampler.RegisterSpace));
+ return reportValueError(Ctx, "RegisterSpace", Sampler.RegisterSpace);
if (!dxbc::isValidShaderVisibility(Sampler.ShaderVisibility))
- DeferredErrs =
- joinErrors(std::move(DeferredErrs),
- llvm::make_error<RootSignatureValidationError<uint32_t>>(
- "ShaderVisibility", Sampler.ShaderVisibility));
+ return reportValueError(Ctx, "ShaderVisibility",
+ Sampler.ShaderVisibility);
}
- return DeferredErrs;
+ return false;
}
-llvm::Expected<mcdxbc::RootSignatureDesc>
-MetadataParser::ParseRootSignature(uint32_t Version) {
- Error DeferredErrs = Error::success();
- mcdxbc::RootSignatureDesc RSD;
- RSD.Version = Version;
+bool MetadataParser::ParseRootSignature(LLVMContext *Ctx,
+ mcdxbc::RootSignatureDesc &RSD) {
+ bool HasError = false;
+
+ // Loop through the Root Elements of the root signature.
for (const auto &Operand : Root->operands()) {
MDNode *Element = dyn_cast<MDNode>(Operand);
if (Element == nullptr)
- return joinErrors(std::move(DeferredErrs),
- make_error<GenericRSMetadataError>(
- "Missing Root Element Metadata Node."));
+ return reportError(Ctx, "Missing Root Element Metadata Node.");
- if (auto Err = parseRootSignatureElement(RSD, Element)) {
- DeferredErrs = joinErrors(std::move(DeferredErrs), std::move(Err));
- }
+ HasError = HasError || parseRootSignatureElement(Ctx, RSD, Element) ||
+ validateRootSignature(Ctx, RSD);
}
- if (auto Err = validateRootSignature(RSD))
- DeferredErrs = joinErrors(std::move(DeferredErrs), std::move(Err));
-
- if (DeferredErrs)
- return std::move(DeferredErrs);
-
- return std::move(RSD);
+ return HasError;
}
} // namespace rootsig
} // namespace hlsl
diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.cpp b/llvm/lib/Target/DirectX/DXILRootSignature.cpp
index 6459294169fe3..f5e51978e5a4e 100644
--- a/llvm/lib/Target/DirectX/DXILRootSignature.cpp
+++ b/llvm/lib/Target/DirectX/DXILRootSignature.cpp
@@ -34,6 +34,14 @@
using namespace llvm;
using namespace llvm::dxil;
+static std::optional<uint32_t> extractMdIntValue(MDNode *Node,
+ unsigned int OpId) {
+ if (auto *CI =
+ mdconst::dyn_extract<ConstantInt>(Node->getOperand(OpId).get()))
+ return CI->getZExtValue();
+ return std::nullopt;
+}
+
static bool reportError(LLVMContext *Ctx, Twine Message,
DiagnosticSeverity Severity = DS_Error) {
Ctx->diagnose(DiagnosticInfoGeneric(Message, Severity));
@@ -45,15 +53,15 @@ analyzeModule(Module &M) {
/** Root Signature are specified as following in the metadata:
- !dx.rootsignatures = !{!2} ; list of function/root signature pairs
- !2 = !{ ptr @main, !3 } ; function, root signature
- !3 = !{ !4, !5, !6, !7 } ; list of root signature elements
+ !dx.rootsignatures = !{!2} ; list of function/root signature pairs
+ !2 = !{ ptr @main, !3 } ; function, root signature
+ !3 = !{ !4, !5, !6, !7 } ; list of root signature elements
- So for each MDNode inside dx.rootsignatures NamedMDNode
- (the Root parameter of this function), the parsing process needs
- to loop through each of its operands and process the function,
- signature pair.
- */
+ So for each MDNode inside dx.rootsignatures NamedMDNode
+ (the Root parameter of this function), the parsing process needs
+ to loop through each of its operands and process the function,
+ signature pair.
+*/
LLVMContext *Ctx = &M.getContext();
@@ -103,26 +111,14 @@ analyzeModule(Module &M) {
reportError(Ctx, "Root Element is not a metadata node.");
continue;
}
- uint32_t Version = 1;
- if (std::optional<uint32_t> V =
- llvm::hlsl::rootsig::extractMdIntValue(RSDefNode, 2))
- Version = *V;
+ mcdxbc::RootSignatureDesc RSD;
+ if (std::optional<uint32_t> Version = extractMdIntValue(RSDefNode, 2))
+ RSD.Version = *Version;
else {
reportError(Ctx, "Invalid RSDefNode value, expected constant int");
continue;
}
- llvm::hlsl::rootsig::MetadataParser MDParser(RootElementListNode);
- llvm::Expected<mcdxbc::RootSignatureDesc> RSDOrErr =
- MDParser.ParseRootSignature(Version);
-
- if (auto Err = RSDOrErr.takeError()) {
- reportError(Ctx, toString(std::move(Err)));
- continue;
- }
-
- auto &RSD = *RSDOrErr;
-
// Clang emits the root signature data in dxcontainer following a specific
// sequence. First the header, then the root parameters. So the header
// offset will always equal to the header size.
@@ -131,6 +127,12 @@ analyzeModule(Module &M) {
// static sampler offset is calculated when writting dxcontainer.
RSD.StaticSamplersOffset = 0u;
+ hlsl::rootsig::MetadataParser MDParser(RootElementListNode);
+
+ if (MDParser.ParseRootSignature(Ctx, RSD)) {
+ return RSDMap;
+ }
+
RSDMap.insert(std::make_pair(F, RSD));
}
diff --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-Invalid-RangeType.ll b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-Invalid-RangeType.ll
index 644e4e4348980..0f7116307c315 100644
--- a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-Invalid-RangeType.ll
+++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-Invalid-RangeType.ll
@@ -2,7 +2,7 @@
target triple = "dxil-unknown-shadermodel6.0-compute"
-; CHECK: error: Invalid Descriptor Range type:Invalid
+; CHECK: error: Invalid Descriptor Range type: Invalid
; CHECK-NOT: Root Signature Definitions
define void @main() #0 {
diff --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-Flags-Error.ll b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-Flags-Error.ll
index 41e97701dcc20..65511160f230d 100644
--- a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-Flags-Error.ll
+++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-Flags-Error.ll
@@ -2,7 +2,7 @@
target triple = "dxil-unknown-shadermodel6.0-compute"
-; CHECK: error: Invalid Root Signature Element:NOTRootFlags
+; CHECK: error: Invalid Root Signature Element: NOTRootFlags
; CHECK-NOT: Root Signature Definitions
diff --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootDescriptor-Invalid-RegisterKind.ll b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootDescriptor-Invalid-RegisterKind.ll
index 9e63b06674ebc..579528d8b5e13 100644
--- a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootDescriptor-Invalid-RegisterKind.ll
+++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootDescriptor-Invalid-RegisterKind.ll
@@ -3,7 +3,7 @@
target triple = "dxil-unknown-shadermodel6.0-compute"
-; CHECK: error: Invalid Root Signature Element:Invalid
+; CHECK: error: Invalid Root Signature Element: Invalid
; CHECK-NOT: Root Signature Definitions
define void @main() #0 {
entry:
diff --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-StaticSamplers-Invalid-MaxLod.ll b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-StaticSamplers-Invalid-MaxLod.ll
index 855e0c0cb6e51..7e7d56eff309c 100644
--- a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-StaticSamplers-Invalid-MaxLod.ll
+++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-StaticSamplers-Invalid-MaxLod.ll
@@ -3,7 +3,7 @@
target triple = "dxil-unknown-shadermodel6.0-compute"
-; CHECK: error: Invalid value for MaxLOD: nan
+; CHECK: error: Invalid value for MaxLOD: 0
; CHECK-NOT: Root Signature Definitions
define void @main() #0 {
diff --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-StaticSamplers-Invalid-MinLod.ll b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-StaticSamplers-Invalid-MinLod.ll
index 812749b9ed824..d958f10d3c1af 100644
--- a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-StaticSamplers-Invalid-MinLod.ll
+++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-StaticSamplers-Invalid-MinLod.ll
@@ -3,7 +3,7 @@
target triple = "dxil-unknown-shadermodel6.0-compute"
-; CHECK: error: Invalid value for MinLOD: nan
+; CHECK: error: Invalid value for MinLOD: 0
; CHECK-NOT: Root Signature Definitions
define void @main() #0 {
diff --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-StaticSamplers-Invalid-MinLopBias.ll b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-StaticSamplers-Invalid-MinLopBias.ll
index 6898aec6f2e49..dc4cb8987e777 100644
--- a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-StaticSamplers-Invalid-MinLopBias.ll
+++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-StaticSamplers-Invalid-MinLopBias.ll
@@ -3,7 +3,7 @@
target triple = "dxil-unknown-shadermodel6.0-compute"
-; CHECK: error: Invalid value for MipLODBias: 6.660000e+02
+; CHECK: error: Invalid value for MipLODBias: 666
; CHECK-NOT: Root Signature Definitions
define void @main() #0 {
>From ca779829c86054a6955cbe51d064a1b8085580cb Mon Sep 17 00:00:00 2001
From: joaosaffran <joao.saffran at microsoft.com>
Date: Thu, 17 Jul 2025 01:07:40 +0000
Subject: [PATCH 05/10] clean
---
.../llvm/Frontend/HLSL/RootSignatureMetadata.h | 1 +
llvm/lib/Target/DirectX/DXILRootSignature.cpp | 16 ++++++++--------
...ignature-StaticSamplers-Invalid-MinLopBias.ll | 2 +-
3 files changed, 10 insertions(+), 9 deletions(-)
diff --git a/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h b/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h
index e0639530cd536..cd5966db42b5f 100644
--- a/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h
+++ b/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h
@@ -25,6 +25,7 @@ class Metadata;
namespace hlsl {
namespace rootsig {
+
class MetadataBuilder {
public:
MetadataBuilder(llvm::LLVMContext &Ctx, ArrayRef<RootElement> Elements)
diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.cpp b/llvm/lib/Target/DirectX/DXILRootSignature.cpp
index f5e51978e5a4e..c09f51169c4ae 100644
--- a/llvm/lib/Target/DirectX/DXILRootSignature.cpp
+++ b/llvm/lib/Target/DirectX/DXILRootSignature.cpp
@@ -53,14 +53,14 @@ analyzeModule(Module &M) {
/** Root Signature are specified as following in the metadata:
- !dx.rootsignatures = !{!2} ; list of function/root signature pairs
- !2 = !{ ptr @main, !3 } ; function, root signature
- !3 = !{ !4, !5, !6, !7 } ; list of root signature elements
-
- So for each MDNode inside dx.rootsignatures NamedMDNode
- (the Root parameter of this function), the parsing process needs
- to loop through each of its operands and process the function,
- signature pair.
+ !dx.rootsignatures = !{!2} ; list of function/root signature pairs
+ !2 = !{ ptr @main, !3 } ; function, root signature
+ !3 = !{ !4, !5, !6, !7 } ; list of root signature elements
+
+ So for each MDNode inside dx.rootsignatures NamedMDNode
+ (the Root parameter of this function), the parsing process needs
+ to loop through each of its operands and process the function,
+ signature pair.
*/
LLVMContext *Ctx = &M.getContext();
diff --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-StaticSamplers-Invalid-MinLopBias.ll b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-StaticSamplers-Invalid-MinLopBias.ll
index dc4cb8987e777..34b27eb40f5fb 100644
--- a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-StaticSamplers-Invalid-MinLopBias.ll
+++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-StaticSamplers-Invalid-MinLopBias.ll
@@ -3,7 +3,7 @@
target triple = "dxil-unknown-shadermodel6.0-compute"
-; CHECK: error: Invalid value for MipLODBias: 666
+; CHECK: error: Invalid value for MipLODBias: 666
; CHECK-NOT: Root Signature Definitions
define void @main() #0 {
>From 0c6ac32423d68cd5e18db553daf68b0fc503e923 Mon Sep 17 00:00:00 2001
From: joaosaffran <joao.saffran at microsoft.com>
Date: Thu, 17 Jul 2025 01:08:42 +0000
Subject: [PATCH 06/10] clean
---
llvm/lib/Target/DirectX/DXILRootSignature.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.cpp b/llvm/lib/Target/DirectX/DXILRootSignature.cpp
index c09f51169c4ae..ebdfcaa566b51 100644
--- a/llvm/lib/Target/DirectX/DXILRootSignature.cpp
+++ b/llvm/lib/Target/DirectX/DXILRootSignature.cpp
@@ -61,7 +61,7 @@ analyzeModule(Module &M) {
(the Root parameter of this function), the parsing process needs
to loop through each of its operands and process the function,
signature pair.
-*/
+ */
LLVMContext *Ctx = &M.getContext();
>From ea96d91be1e437109e51b46bf968050ca862ded3 Mon Sep 17 00:00:00 2001
From: joaosaffran <joao.saffran at microsoft.com>
Date: Tue, 15 Jul 2025 21:27:04 +0000
Subject: [PATCH 07/10] Improve error handling
---
.../Frontend/HLSL/RootSignatureMetadata.h | 130 +++++--
.../Frontend/HLSL/RootSignatureMetadata.cpp | 345 ++++++++++--------
llvm/lib/Target/DirectX/DXILRootSignature.cpp | 33 +-
...ature-DescriptorTable-Invalid-RangeType.ll | 2 +-
.../RootSignature-Flags-Error.ll | 2 +-
...ure-RootDescriptor-Invalid-RegisterKind.ll | 2 +-
...Signature-StaticSamplers-Invalid-MaxLod.ll | 2 +-
...Signature-StaticSamplers-Invalid-MinLod.ll | 2 +-
...ature-StaticSamplers-Invalid-MinLopBias.ll | 2 +-
9 files changed, 324 insertions(+), 196 deletions(-)
diff --git a/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h b/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h
index cd5966db42b5f..729ea22d3c8ab 100644
--- a/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h
+++ b/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h
@@ -26,6 +26,96 @@ class Metadata;
namespace hlsl {
namespace rootsig {
+inline std::optional<uint32_t> extractMdIntValue(MDNode *Node,
+ unsigned int OpId) {
+ if (auto *CI =
+ mdconst::dyn_extract<ConstantInt>(Node->getOperand(OpId).get()))
+ return CI->getZExtValue();
+ return std::nullopt;
+}
+
+inline std::optional<float> extractMdFloatValue(MDNode *Node,
+ unsigned int OpId) {
+ if (auto *CI = mdconst::dyn_extract<ConstantFP>(Node->getOperand(OpId).get()))
+ return CI->getValueAPF().convertToFloat();
+ return std::nullopt;
+}
+
+inline std::optional<StringRef> extractMdStringValue(MDNode *Node,
+ unsigned int OpId) {
+ MDString *NodeText = dyn_cast<MDString>(Node->getOperand(OpId));
+ if (NodeText == nullptr)
+ return std::nullopt;
+ return NodeText->getString();
+}
+
+template <typename T>
+class RootSignatureValidationError
+ : public ErrorInfo<RootSignatureValidationError<T>> {
+public:
+ static char ID;
+ std::string ParamName;
+ T Value;
+
+ RootSignatureValidationError(StringRef ParamName, T Value)
+ : ParamName(ParamName.str()), Value(Value) {}
+
+ void log(raw_ostream &OS) const override {
+ OS << "Invalid value for " << ParamName << ": " << Value;
+ }
+
+ std::error_code convertToErrorCode() const override {
+ return llvm::inconvertibleErrorCode();
+ }
+};
+
+class GenericRSMetadataError : public ErrorInfo<GenericRSMetadataError> {
+public:
+ static char ID;
+ std::string Message;
+
+ GenericRSMetadataError(Twine Message) : Message(Message.str()) {}
+
+ void log(raw_ostream &OS) const override { OS << Message; }
+
+ std::error_code convertToErrorCode() const override {
+ return llvm::inconvertibleErrorCode();
+ }
+};
+
+class InvalidRSMetadataFormat : public ErrorInfo<InvalidRSMetadataFormat> {
+public:
+ static char ID;
+ std::string ElementName;
+
+ InvalidRSMetadataFormat(StringRef ElementName)
+ : ElementName(ElementName.str()) {}
+
+ void log(raw_ostream &OS) const override {
+ OS << "Invalid format for " << ElementName;
+ }
+
+ std::error_code convertToErrorCode() const override {
+ return llvm::inconvertibleErrorCode();
+ }
+};
+
+class InvalidRSMetadataValue : public ErrorInfo<InvalidRSMetadataValue> {
+public:
+ static char ID;
+ std::string ParamName;
+
+ InvalidRSMetadataValue(StringRef ParamName) : ParamName(ParamName.str()) {}
+
+ void log(raw_ostream &OS) const override {
+ OS << "Invalid value for " << ParamName;
+ }
+
+ std::error_code convertToErrorCode() const override {
+ return llvm::inconvertibleErrorCode();
+ }
+};
+
class MetadataBuilder {
public:
MetadataBuilder(llvm::LLVMContext &Ctx, ArrayRef<RootElement> Elements)
@@ -67,29 +157,27 @@ class MetadataParser {
MetadataParser(MDNode *Root) : Root(Root) {}
/// Iterates through root signature and converts them into MapT
- LLVM_ABI bool ParseRootSignature(LLVMContext *Ctx,
- mcdxbc::RootSignatureDesc &RSD);
+ LLVM_ABI llvm::Expected<llvm::mcdxbc::RootSignatureDesc>
+ ParseRootSignature(uint32_t Version);
private:
- bool parseRootFlags(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
- MDNode *RootFlagNode);
- bool parseRootConstants(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
- MDNode *RootConstantNode);
- bool parseRootDescriptors(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
- MDNode *RootDescriptorNode,
- RootSignatureElementKind ElementKind);
- bool parseDescriptorRange(LLVMContext *Ctx, mcdxbc::DescriptorTable &Table,
- MDNode *RangeDescriptorNode);
- bool parseDescriptorTable(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
- MDNode *DescriptorTableNode);
- bool parseRootSignatureElement(LLVMContext *Ctx,
- mcdxbc::RootSignatureDesc &RSD,
- MDNode *Element);
- bool parseStaticSampler(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
- MDNode *StaticSamplerNode);
-
- bool validateRootSignature(LLVMContext *Ctx,
- const llvm::mcdxbc::RootSignatureDesc &RSD);
+ llvm::Error parseRootFlags(mcdxbc::RootSignatureDesc &RSD,
+ MDNode *RootFlagNode);
+ llvm::Error parseRootConstants(mcdxbc::RootSignatureDesc &RSD,
+ MDNode *RootConstantNode);
+ llvm::Error parseRootDescriptors(mcdxbc::RootSignatureDesc &RSD,
+ MDNode *RootDescriptorNode,
+ RootSignatureElementKind ElementKind);
+ llvm::Error parseDescriptorRange(mcdxbc::DescriptorTable &Table,
+ MDNode *RangeDescriptorNode);
+ llvm::Error parseDescriptorTable(mcdxbc::RootSignatureDesc &RSD,
+ MDNode *DescriptorTableNode);
+ llvm::Error parseRootSignatureElement(mcdxbc::RootSignatureDesc &RSD,
+ MDNode *Element);
+ llvm::Error parseStaticSampler(mcdxbc::RootSignatureDesc &RSD,
+ MDNode *StaticSamplerNode);
+
+ llvm::Error validateRootSignature(const llvm::mcdxbc::RootSignatureDesc &RSD);
MDNode *Root;
};
diff --git a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
index 53f59349ae029..23c1815d438ad 100644
--- a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
+++ b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
@@ -13,7 +13,6 @@
#include "llvm/Frontend/HLSL/RootSignatureMetadata.h"
#include "llvm/Frontend/HLSL/RootSignatureValidations.h"
-#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Metadata.h"
#include "llvm/Support/ScopedPrinter.h"
@@ -22,41 +21,11 @@ namespace llvm {
namespace hlsl {
namespace rootsig {
-static std::optional<uint32_t> extractMdIntValue(MDNode *Node,
- unsigned int OpId) {
- if (auto *CI =
- mdconst::dyn_extract<ConstantInt>(Node->getOperand(OpId).get()))
- return CI->getZExtValue();
- return std::nullopt;
-}
+char GenericRSMetadataError::ID;
+char InvalidRSMetadataFormat::ID;
+char InvalidRSMetadataValue::ID;
-static std::optional<float> extractMdFloatValue(MDNode *Node,
- unsigned int OpId) {
- if (auto *CI = mdconst::dyn_extract<ConstantFP>(Node->getOperand(OpId).get()))
- return CI->getValueAPF().convertToFloat();
- return std::nullopt;
-}
-
-static std::optional<StringRef> extractMdStringValue(MDNode *Node,
- unsigned int OpId) {
- MDString *NodeText = dyn_cast<MDString>(Node->getOperand(OpId));
- if (NodeText == nullptr)
- return std::nullopt;
- return NodeText->getString();
-}
-
-static bool reportError(LLVMContext *Ctx, Twine Message,
- DiagnosticSeverity Severity = DS_Error) {
- Ctx->diagnose(DiagnosticInfoGeneric(Message, Severity));
- return true;
-}
-
-static bool reportValueError(LLVMContext *Ctx, Twine ParamName,
- uint32_t Value) {
- Ctx->diagnose(DiagnosticInfoGeneric(
- "Invalid value for " + ParamName + ": " + Twine(Value), DS_Error));
- return true;
-}
+template <typename T> char RootSignatureValidationError<T>::ID;
static const EnumEntry<dxil::ResourceClass> ResourceClassNames[] = {
{"CBV", dxil::ResourceClass::CBuffer},
@@ -227,27 +196,23 @@ MDNode *MetadataBuilder::BuildStaticSampler(const StaticSampler &Sampler) {
return MDNode::get(Ctx, Operands);
}
-bool MetadataParser::parseRootFlags(LLVMContext *Ctx,
- mcdxbc::RootSignatureDesc &RSD,
- MDNode *RootFlagNode) {
-
+llvm::Error MetadataParser::parseRootFlags(mcdxbc::RootSignatureDesc &RSD,
+ MDNode *RootFlagNode) {
if (RootFlagNode->getNumOperands() != 2)
- return reportError(Ctx, "Invalid format for RootFlag Element");
+ return make_error<InvalidRSMetadataFormat>("RootFlag Element");
if (std::optional<uint32_t> Val = extractMdIntValue(RootFlagNode, 1))
RSD.Flags = *Val;
else
- return reportError(Ctx, "Invalid value for RootFlag");
+ return make_error<InvalidRSMetadataValue>("RootFlag");
- return false;
+ return llvm::Error::success();
}
-bool MetadataParser::parseRootConstants(LLVMContext *Ctx,
- mcdxbc::RootSignatureDesc &RSD,
- MDNode *RootConstantNode) {
-
+llvm::Error MetadataParser::parseRootConstants(mcdxbc::RootSignatureDesc &RSD,
+ MDNode *RootConstantNode) {
if (RootConstantNode->getNumOperands() != 5)
- return reportError(Ctx, "Invalid format for RootConstants Element");
+ return make_error<InvalidRSMetadataFormat>("RootConstants Element");
dxbc::RTS0::v1::RootParameterHeader Header;
// The parameter offset doesn't matter here - we recalculate it during
@@ -258,39 +223,40 @@ bool MetadataParser::parseRootConstants(LLVMContext *Ctx,
if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 1))
Header.ShaderVisibility = *Val;
else
- return reportError(Ctx, "Invalid value for ShaderVisibility");
+ return make_error<InvalidRSMetadataValue>("ShaderVisibility");
dxbc::RTS0::v1::RootConstants Constants;
if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 2))
Constants.ShaderRegister = *Val;
else
- return reportError(Ctx, "Invalid value for ShaderRegister");
+ return make_error<InvalidRSMetadataValue>("ShaderRegister");
if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 3))
Constants.RegisterSpace = *Val;
else
- return reportError(Ctx, "Invalid value for RegisterSpace");
+ return make_error<InvalidRSMetadataValue>("RegisterSpace");
if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 4))
Constants.Num32BitValues = *Val;
else
- return reportError(Ctx, "Invalid value for Num32BitValues");
+ return make_error<InvalidRSMetadataValue>("Num32BitValues");
RSD.ParametersContainer.addParameter(Header, Constants);
- return false;
+ return llvm::Error::success();
}
-bool MetadataParser::parseRootDescriptors(
- LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
- MDNode *RootDescriptorNode, RootSignatureElementKind ElementKind) {
+llvm::Error
+MetadataParser::parseRootDescriptors(mcdxbc::RootSignatureDesc &RSD,
+ MDNode *RootDescriptorNode,
+ RootSignatureElementKind ElementKind) {
assert(ElementKind == RootSignatureElementKind::SRV ||
ElementKind == RootSignatureElementKind::UAV ||
ElementKind == RootSignatureElementKind::CBV &&
- "parseRootDescriptors should only be called with RootDescriptor "
+ "parseRootDescriptors should only be called with RootDescriptor"
"element kind.");
if (RootDescriptorNode->getNumOperands() != 5)
- return reportError(Ctx, "Invalid format for Root Descriptor Element");
+ return make_error<InvalidRSMetadataFormat>("Root Descriptor Element");
dxbc::RTS0::v1::RootParameterHeader Header;
switch (ElementKind) {
@@ -311,40 +277,38 @@ bool MetadataParser::parseRootDescriptors(
if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 1))
Header.ShaderVisibility = *Val;
else
- return reportError(Ctx, "Invalid value for ShaderVisibility");
+ return make_error<InvalidRSMetadataValue>("ShaderVisibility");
dxbc::RTS0::v2::RootDescriptor Descriptor;
if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 2))
Descriptor.ShaderRegister = *Val;
else
- return reportError(Ctx, "Invalid value for ShaderRegister");
+ return make_error<InvalidRSMetadataValue>("ShaderRegister");
if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 3))
Descriptor.RegisterSpace = *Val;
else
- return reportError(Ctx, "Invalid value for RegisterSpace");
+ return make_error<InvalidRSMetadataValue>("RegisterSpace");
if (RSD.Version == 1) {
RSD.ParametersContainer.addParameter(Header, Descriptor);
- return false;
+ return llvm::Error::success();
}
assert(RSD.Version > 1);
if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 4))
Descriptor.Flags = *Val;
else
- return reportError(Ctx, "Invalid value for Root Descriptor Flags");
+ return make_error<InvalidRSMetadataValue>("Root Descriptor Flags");
RSD.ParametersContainer.addParameter(Header, Descriptor);
- return false;
+ return llvm::Error::success();
}
-bool MetadataParser::parseDescriptorRange(LLVMContext *Ctx,
- mcdxbc::DescriptorTable &Table,
- MDNode *RangeDescriptorNode) {
-
+llvm::Error MetadataParser::parseDescriptorRange(mcdxbc::DescriptorTable &Table,
+ MDNode *RangeDescriptorNode) {
if (RangeDescriptorNode->getNumOperands() != 6)
- return reportError(Ctx, "Invalid format for Descriptor Range");
+ return make_error<InvalidRSMetadataFormat>("Descriptor Range");
dxbc::RTS0::v2::DescriptorRange Range;
@@ -352,7 +316,7 @@ bool MetadataParser::parseDescriptorRange(LLVMContext *Ctx,
extractMdStringValue(RangeDescriptorNode, 0);
if (!ElementText.has_value())
- return reportError(Ctx, "Descriptor Range, first element is not a string.");
+ return make_error<InvalidRSMetadataFormat>("Descriptor Range");
Range.RangeType =
StringSwitch<uint32_t>(*ElementText)
@@ -364,50 +328,50 @@ bool MetadataParser::parseDescriptorRange(LLVMContext *Ctx,
.Default(~0U);
if (Range.RangeType == ~0U)
- return reportError(Ctx, "Invalid Descriptor Range type: " + *ElementText);
+ return make_error<GenericRSMetadataError>("Invalid Descriptor Range type:" +
+ *ElementText);
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 1))
Range.NumDescriptors = *Val;
else
- return reportError(Ctx, "Invalid value for Number of Descriptor in Range");
+ return make_error<GenericRSMetadataError>("Number of Descriptor in Range");
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 2))
Range.BaseShaderRegister = *Val;
else
- return reportError(Ctx, "Invalid value for BaseShaderRegister");
+ return make_error<InvalidRSMetadataValue>("BaseShaderRegister");
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 3))
Range.RegisterSpace = *Val;
else
- return reportError(Ctx, "Invalid value for RegisterSpace");
+ return make_error<InvalidRSMetadataValue>("RegisterSpace");
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 4))
Range.OffsetInDescriptorsFromTableStart = *Val;
else
- return reportError(Ctx,
- "Invalid value for OffsetInDescriptorsFromTableStart");
+ return make_error<InvalidRSMetadataValue>(
+ "OffsetInDescriptorsFromTableStart");
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 5))
Range.Flags = *Val;
else
- return reportError(Ctx, "Invalid value for Descriptor Range Flags");
+ return make_error<InvalidRSMetadataValue>("Descriptor Range Flags");
Table.Ranges.push_back(Range);
- return false;
+ return llvm::Error::success();
}
-bool MetadataParser::parseDescriptorTable(LLVMContext *Ctx,
- mcdxbc::RootSignatureDesc &RSD,
- MDNode *DescriptorTableNode) {
+llvm::Error MetadataParser::parseDescriptorTable(mcdxbc::RootSignatureDesc &RSD,
+ MDNode *DescriptorTableNode) {
const unsigned int NumOperands = DescriptorTableNode->getNumOperands();
if (NumOperands < 2)
- return reportError(Ctx, "Invalid format for Descriptor Table");
+ return make_error<InvalidRSMetadataFormat>("Descriptor Table");
dxbc::RTS0::v1::RootParameterHeader Header;
if (std::optional<uint32_t> Val = extractMdIntValue(DescriptorTableNode, 1))
Header.ShaderVisibility = *Val;
else
- return reportError(Ctx, "Invalid value for ShaderVisibility");
+ return make_error<InvalidRSMetadataValue>("ShaderVisibility");
mcdxbc::DescriptorTable Table;
Header.ParameterType =
@@ -416,98 +380,98 @@ bool MetadataParser::parseDescriptorTable(LLVMContext *Ctx,
for (unsigned int I = 2; I < NumOperands; I++) {
MDNode *Element = dyn_cast<MDNode>(DescriptorTableNode->getOperand(I));
if (Element == nullptr)
- return reportError(Ctx, "Missing Root Element Metadata Node.");
+ return make_error<GenericRSMetadataError>(
+ "Missing Root Element Metadata Node.");
- if (parseDescriptorRange(Ctx, Table, Element))
- return true;
+ if (auto Err = parseDescriptorRange(Table, Element))
+ return Err;
}
RSD.ParametersContainer.addParameter(Header, Table);
- return false;
+ return llvm::Error::success();
}
-bool MetadataParser::parseStaticSampler(LLVMContext *Ctx,
- mcdxbc::RootSignatureDesc &RSD,
- MDNode *StaticSamplerNode) {
+llvm::Error MetadataParser::parseStaticSampler(mcdxbc::RootSignatureDesc &RSD,
+ MDNode *StaticSamplerNode) {
if (StaticSamplerNode->getNumOperands() != 14)
- return reportError(Ctx, "Invalid format for Static Sampler");
+ return make_error<InvalidRSMetadataFormat>("Static Sampler");
dxbc::RTS0::v1::StaticSampler Sampler;
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 1))
Sampler.Filter = *Val;
else
- return reportError(Ctx, "Invalid value for Filter");
+ return make_error<InvalidRSMetadataValue>("Filter");
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 2))
Sampler.AddressU = *Val;
else
- return reportError(Ctx, "Invalid value for AddressU");
+ return make_error<InvalidRSMetadataValue>("AddressU");
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 3))
Sampler.AddressV = *Val;
else
- return reportError(Ctx, "Invalid value for AddressV");
+ return make_error<InvalidRSMetadataValue>("AddressV");
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 4))
Sampler.AddressW = *Val;
else
- return reportError(Ctx, "Invalid value for AddressW");
+ return make_error<InvalidRSMetadataValue>("AddressW");
if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 5))
Sampler.MipLODBias = *Val;
else
- return reportError(Ctx, "Invalid value for MipLODBias");
+ return make_error<InvalidRSMetadataValue>("MipLODBias");
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 6))
Sampler.MaxAnisotropy = *Val;
else
- return reportError(Ctx, "Invalid value for MaxAnisotropy");
+ return make_error<InvalidRSMetadataValue>("MaxAnisotropy");
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 7))
Sampler.ComparisonFunc = *Val;
else
- return reportError(Ctx, "Invalid value for ComparisonFunc ");
+ return make_error<InvalidRSMetadataValue>("ComparisonFunc");
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 8))
Sampler.BorderColor = *Val;
else
- return reportError(Ctx, "Invalid value for ComparisonFunc ");
+ return make_error<InvalidRSMetadataValue>("BorderColor");
if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 9))
Sampler.MinLOD = *Val;
else
- return reportError(Ctx, "Invalid value for MinLOD");
+ return make_error<InvalidRSMetadataValue>("MinLOD");
if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 10))
Sampler.MaxLOD = *Val;
else
- return reportError(Ctx, "Invalid value for MaxLOD");
+ return make_error<InvalidRSMetadataValue>("MaxLOD");
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 11))
Sampler.ShaderRegister = *Val;
else
- return reportError(Ctx, "Invalid value for ShaderRegister");
+ return make_error<InvalidRSMetadataValue>("ShaderRegister");
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 12))
Sampler.RegisterSpace = *Val;
else
- return reportError(Ctx, "Invalid value for RegisterSpace");
+ return make_error<InvalidRSMetadataValue>("RegisterSpace");
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 13))
Sampler.ShaderVisibility = *Val;
else
- return reportError(Ctx, "Invalid value for ShaderVisibility");
+ return make_error<InvalidRSMetadataValue>("ShaderVisibility");
RSD.StaticSamplers.push_back(Sampler);
- return false;
+ return llvm::Error::success();
}
-bool MetadataParser::parseRootSignatureElement(LLVMContext *Ctx,
- mcdxbc::RootSignatureDesc &RSD,
- MDNode *Element) {
+llvm::Error
+MetadataParser::parseRootSignatureElement(mcdxbc::RootSignatureDesc &RSD,
+ MDNode *Element) {
std::optional<StringRef> ElementText = extractMdStringValue(Element, 0);
if (!ElementText.has_value())
- return reportError(Ctx, "Invalid format for Root Element");
+ return make_error<InvalidRSMetadataFormat>("Root Element");
RootSignatureElementKind ElementKind =
StringSwitch<RootSignatureElementKind>(*ElementText)
@@ -523,38 +487,48 @@ bool MetadataParser::parseRootSignatureElement(LLVMContext *Ctx,
switch (ElementKind) {
case RootSignatureElementKind::RootFlags:
- return parseRootFlags(Ctx, RSD, Element);
+ return parseRootFlags(RSD, Element);
case RootSignatureElementKind::RootConstants:
- return parseRootConstants(Ctx, RSD, Element);
+ return parseRootConstants(RSD, Element);
case RootSignatureElementKind::CBV:
case RootSignatureElementKind::SRV:
case RootSignatureElementKind::UAV:
- return parseRootDescriptors(Ctx, RSD, Element, ElementKind);
+ return parseRootDescriptors(RSD, Element, ElementKind);
case RootSignatureElementKind::DescriptorTable:
- return parseDescriptorTable(Ctx, RSD, Element);
+ return parseDescriptorTable(RSD, Element);
case RootSignatureElementKind::StaticSamplers:
- return parseStaticSampler(Ctx, RSD, Element);
+ return parseStaticSampler(RSD, Element);
case RootSignatureElementKind::Error:
- return reportError(Ctx, "Invalid Root Signature Element: " + *ElementText);
+ return make_error<GenericRSMetadataError>(
+ "Invalid Root Signature Element:" + *ElementText);
}
llvm_unreachable("Unhandled RootSignatureElementKind enum.");
}
-bool MetadataParser::validateRootSignature(
- LLVMContext *Ctx, const llvm::mcdxbc::RootSignatureDesc &RSD) {
+llvm::Error MetadataParser::validateRootSignature(
+ const llvm::mcdxbc::RootSignatureDesc &RSD) {
+ Error DeferredErrs = Error::success();
if (!llvm::hlsl::rootsig::verifyVersion(RSD.Version)) {
- return reportValueError(Ctx, "Version", RSD.Version);
+ DeferredErrs =
+ joinErrors(std::move(DeferredErrs),
+ llvm::make_error<RootSignatureValidationError<uint32_t>>(
+ "Version", RSD.Version));
}
if (!llvm::hlsl::rootsig::verifyRootFlag(RSD.Flags)) {
- return reportValueError(Ctx, "RootFlags", RSD.Flags);
+ DeferredErrs =
+ joinErrors(std::move(DeferredErrs),
+ llvm::make_error<RootSignatureValidationError<uint32_t>>(
+ "RootFlags", RSD.Flags));
}
for (const mcdxbc::RootParameterInfo &Info : RSD.ParametersContainer) {
if (!dxbc::isValidShaderVisibility(Info.Header.ShaderVisibility))
- return reportValueError(Ctx, "ShaderVisibility",
- Info.Header.ShaderVisibility);
+ DeferredErrs =
+ joinErrors(std::move(DeferredErrs),
+ llvm::make_error<RootSignatureValidationError<uint32_t>>(
+ "ShaderVisibility", Info.Header.ShaderVisibility));
assert(dxbc::isValidParameterType(Info.Header.ParameterType) &&
"Invalid value for ParameterType");
@@ -567,16 +541,24 @@ bool MetadataParser::validateRootSignature(
const dxbc::RTS0::v2::RootDescriptor &Descriptor =
RSD.ParametersContainer.getRootDescriptor(Info.Location);
if (!llvm::hlsl::rootsig::verifyRegisterValue(Descriptor.ShaderRegister))
- return reportValueError(Ctx, "ShaderRegister",
- Descriptor.ShaderRegister);
+ DeferredErrs =
+ joinErrors(std::move(DeferredErrs),
+ llvm::make_error<RootSignatureValidationError<uint32_t>>(
+ "ShaderRegister", Descriptor.ShaderRegister));
if (!llvm::hlsl::rootsig::verifyRegisterSpace(Descriptor.RegisterSpace))
- return reportValueError(Ctx, "RegisterSpace", Descriptor.RegisterSpace);
+ DeferredErrs =
+ joinErrors(std::move(DeferredErrs),
+ llvm::make_error<RootSignatureValidationError<uint32_t>>(
+ "RegisterSpace", Descriptor.RegisterSpace));
if (RSD.Version > 1) {
if (!llvm::hlsl::rootsig::verifyRootDescriptorFlag(RSD.Version,
Descriptor.Flags))
- return reportValueError(Ctx, "RootDescriptorFlag", Descriptor.Flags);
+ DeferredErrs = joinErrors(
+ std::move(DeferredErrs),
+ llvm::make_error<RootSignatureValidationError<uint32_t>>(
+ "RootDescriptorFlag", Descriptor.Flags));
}
break;
}
@@ -585,17 +567,29 @@ bool MetadataParser::validateRootSignature(
RSD.ParametersContainer.getDescriptorTable(Info.Location);
for (const dxbc::RTS0::v2::DescriptorRange &Range : Table) {
if (!llvm::hlsl::rootsig::verifyRangeType(Range.RangeType))
- return reportValueError(Ctx, "RangeType", Range.RangeType);
+ DeferredErrs = joinErrors(
+ std::move(DeferredErrs),
+ llvm::make_error<RootSignatureValidationError<uint32_t>>(
+ "RangeType", Range.RangeType));
if (!llvm::hlsl::rootsig::verifyRegisterSpace(Range.RegisterSpace))
- return reportValueError(Ctx, "RegisterSpace", Range.RegisterSpace);
+ DeferredErrs = joinErrors(
+ std::move(DeferredErrs),
+ llvm::make_error<RootSignatureValidationError<uint32_t>>(
+ "RegisterSpace", Range.RegisterSpace));
if (!llvm::hlsl::rootsig::verifyNumDescriptors(Range.NumDescriptors))
- return reportValueError(Ctx, "NumDescriptors", Range.NumDescriptors);
+ DeferredErrs = joinErrors(
+ std::move(DeferredErrs),
+ llvm::make_error<RootSignatureValidationError<uint32_t>>(
+ "NumDescriptors", Range.NumDescriptors));
if (!llvm::hlsl::rootsig::verifyDescriptorRangeFlag(
RSD.Version, Range.RangeType, Range.Flags))
- return reportValueError(Ctx, "DescriptorFlag", Range.Flags);
+ DeferredErrs = joinErrors(
+ std::move(DeferredErrs),
+ llvm::make_error<RootSignatureValidationError<uint32_t>>(
+ "DescriptorFlag", Range.Flags));
}
break;
}
@@ -604,64 +598,111 @@ bool MetadataParser::validateRootSignature(
for (const dxbc::RTS0::v1::StaticSampler &Sampler : RSD.StaticSamplers) {
if (!llvm::hlsl::rootsig::verifySamplerFilter(Sampler.Filter))
- return reportValueError(Ctx, "Filter", Sampler.Filter);
+ DeferredErrs =
+ joinErrors(std::move(DeferredErrs),
+ llvm::make_error<RootSignatureValidationError<uint32_t>>(
+ "Filter", Sampler.Filter));
if (!llvm::hlsl::rootsig::verifyAddress(Sampler.AddressU))
- return reportValueError(Ctx, "AddressU", Sampler.AddressU);
+ DeferredErrs =
+ joinErrors(std::move(DeferredErrs),
+ llvm::make_error<RootSignatureValidationError<uint32_t>>(
+ "AddressU", Sampler.AddressU));
if (!llvm::hlsl::rootsig::verifyAddress(Sampler.AddressV))
- return reportValueError(Ctx, "AddressV", Sampler.AddressV);
+ DeferredErrs =
+ joinErrors(std::move(DeferredErrs),
+ llvm::make_error<RootSignatureValidationError<uint32_t>>(
+ "AddressV", Sampler.AddressV));
if (!llvm::hlsl::rootsig::verifyAddress(Sampler.AddressW))
- return reportValueError(Ctx, "AddressW", Sampler.AddressW);
+ DeferredErrs =
+ joinErrors(std::move(DeferredErrs),
+ llvm::make_error<RootSignatureValidationError<uint32_t>>(
+ "AddressW", Sampler.AddressW));
if (!llvm::hlsl::rootsig::verifyMipLODBias(Sampler.MipLODBias))
- return reportValueError(Ctx, "MipLODBias", Sampler.MipLODBias);
+ DeferredErrs =
+ joinErrors(std::move(DeferredErrs),
+ llvm::make_error<RootSignatureValidationError<float>>(
+ "MipLODBias", Sampler.MipLODBias));
if (!llvm::hlsl::rootsig::verifyMaxAnisotropy(Sampler.MaxAnisotropy))
- return reportValueError(Ctx, "MaxAnisotropy", Sampler.MaxAnisotropy);
+ DeferredErrs =
+ joinErrors(std::move(DeferredErrs),
+ llvm::make_error<RootSignatureValidationError<uint32_t>>(
+ "MaxAnisotropy", Sampler.MaxAnisotropy));
if (!llvm::hlsl::rootsig::verifyComparisonFunc(Sampler.ComparisonFunc))
- return reportValueError(Ctx, "ComparisonFunc", Sampler.ComparisonFunc);
+ DeferredErrs =
+ joinErrors(std::move(DeferredErrs),
+ llvm::make_error<RootSignatureValidationError<uint32_t>>(
+ "ComparisonFunc", Sampler.ComparisonFunc));
if (!llvm::hlsl::rootsig::verifyBorderColor(Sampler.BorderColor))
- return reportValueError(Ctx, "BorderColor", Sampler.BorderColor);
+ DeferredErrs =
+ joinErrors(std::move(DeferredErrs),
+ llvm::make_error<RootSignatureValidationError<uint32_t>>(
+ "BorderColor", Sampler.BorderColor));
if (!llvm::hlsl::rootsig::verifyLOD(Sampler.MinLOD))
- return reportValueError(Ctx, "MinLOD", Sampler.MinLOD);
+ DeferredErrs =
+ joinErrors(std::move(DeferredErrs),
+ llvm::make_error<RootSignatureValidationError<float>>(
+ "MinLOD", Sampler.MinLOD));
if (!llvm::hlsl::rootsig::verifyLOD(Sampler.MaxLOD))
- return reportValueError(Ctx, "MaxLOD", Sampler.MaxLOD);
+ DeferredErrs =
+ joinErrors(std::move(DeferredErrs),
+ llvm::make_error<RootSignatureValidationError<float>>(
+ "MaxLOD", Sampler.MaxLOD));
if (!llvm::hlsl::rootsig::verifyRegisterValue(Sampler.ShaderRegister))
- return reportValueError(Ctx, "ShaderRegister", Sampler.ShaderRegister);
+ DeferredErrs =
+ joinErrors(std::move(DeferredErrs),
+ llvm::make_error<RootSignatureValidationError<uint32_t>>(
+ "ShaderRegister", Sampler.ShaderRegister));
if (!llvm::hlsl::rootsig::verifyRegisterSpace(Sampler.RegisterSpace))
- return reportValueError(Ctx, "RegisterSpace", Sampler.RegisterSpace);
+ DeferredErrs =
+ joinErrors(std::move(DeferredErrs),
+ llvm::make_error<RootSignatureValidationError<uint32_t>>(
+ "RegisterSpace", Sampler.RegisterSpace));
if (!dxbc::isValidShaderVisibility(Sampler.ShaderVisibility))
- return reportValueError(Ctx, "ShaderVisibility",
- Sampler.ShaderVisibility);
+ DeferredErrs =
+ joinErrors(std::move(DeferredErrs),
+ llvm::make_error<RootSignatureValidationError<uint32_t>>(
+ "ShaderVisibility", Sampler.ShaderVisibility));
}
- return false;
+ return DeferredErrs;
}
-bool MetadataParser::ParseRootSignature(LLVMContext *Ctx,
- mcdxbc::RootSignatureDesc &RSD) {
- bool HasError = false;
-
- // Loop through the Root Elements of the root signature.
+llvm::Expected<mcdxbc::RootSignatureDesc>
+MetadataParser::ParseRootSignature(uint32_t Version) {
+ Error DeferredErrs = Error::success();
+ mcdxbc::RootSignatureDesc RSD;
+ RSD.Version = Version;
for (const auto &Operand : Root->operands()) {
MDNode *Element = dyn_cast<MDNode>(Operand);
if (Element == nullptr)
- return reportError(Ctx, "Missing Root Element Metadata Node.");
+ return joinErrors(std::move(DeferredErrs),
+ make_error<GenericRSMetadataError>(
+ "Missing Root Element Metadata Node."));
- HasError = HasError || parseRootSignatureElement(Ctx, RSD, Element) ||
- validateRootSignature(Ctx, RSD);
+ if (auto Err = parseRootSignatureElement(RSD, Element)) {
+ DeferredErrs = joinErrors(std::move(DeferredErrs), std::move(Err));
+ }
}
- return HasError;
+ if (auto Err = validateRootSignature(RSD))
+ DeferredErrs = joinErrors(std::move(DeferredErrs), std::move(Err));
+
+ if (DeferredErrs)
+ return std::move(DeferredErrs);
+
+ return std::move(RSD);
}
} // namespace rootsig
} // namespace hlsl
diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.cpp b/llvm/lib/Target/DirectX/DXILRootSignature.cpp
index ebdfcaa566b51..924931ae0da5b 100644
--- a/llvm/lib/Target/DirectX/DXILRootSignature.cpp
+++ b/llvm/lib/Target/DirectX/DXILRootSignature.cpp
@@ -17,6 +17,7 @@
#include "llvm/Analysis/DXILMetadataAnalysis.h"
#include "llvm/BinaryFormat/DXContainer.h"
#include "llvm/Frontend/HLSL/RootSignatureMetadata.h"
+#include "llvm/Frontend/HLSL/RootSignatureMetadata.h"
#include "llvm/Frontend/HLSL/RootSignatureValidations.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DiagnosticInfo.h"
@@ -34,14 +35,6 @@
using namespace llvm;
using namespace llvm::dxil;
-static std::optional<uint32_t> extractMdIntValue(MDNode *Node,
- unsigned int OpId) {
- if (auto *CI =
- mdconst::dyn_extract<ConstantInt>(Node->getOperand(OpId).get()))
- return CI->getZExtValue();
- return std::nullopt;
-}
-
static bool reportError(LLVMContext *Ctx, Twine Message,
DiagnosticSeverity Severity = DS_Error) {
Ctx->diagnose(DiagnosticInfoGeneric(Message, Severity));
@@ -111,14 +104,26 @@ analyzeModule(Module &M) {
reportError(Ctx, "Root Element is not a metadata node.");
continue;
}
- mcdxbc::RootSignatureDesc RSD;
- if (std::optional<uint32_t> Version = extractMdIntValue(RSDefNode, 2))
- RSD.Version = *Version;
+ uint32_t Version = 1;
+ if (std::optional<uint32_t> V =
+ llvm::hlsl::rootsig::extractMdIntValue(RSDefNode, 2))
+ Version = *V;
else {
reportError(Ctx, "Invalid RSDefNode value, expected constant int");
continue;
}
+ llvm::hlsl::rootsig::MetadataParser MDParser(RootElementListNode);
+ llvm::Expected<mcdxbc::RootSignatureDesc> RSDOrErr =
+ MDParser.ParseRootSignature(Version);
+
+ if (auto Err = RSDOrErr.takeError()) {
+ reportError(Ctx, toString(std::move(Err)));
+ continue;
+ }
+
+ auto &RSD = *RSDOrErr;
+
// Clang emits the root signature data in dxcontainer following a specific
// sequence. First the header, then the root parameters. So the header
// offset will always equal to the header size.
@@ -127,12 +132,6 @@ analyzeModule(Module &M) {
// static sampler offset is calculated when writting dxcontainer.
RSD.StaticSamplersOffset = 0u;
- hlsl::rootsig::MetadataParser MDParser(RootElementListNode);
-
- if (MDParser.ParseRootSignature(Ctx, RSD)) {
- return RSDMap;
- }
-
RSDMap.insert(std::make_pair(F, RSD));
}
diff --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-Invalid-RangeType.ll b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-Invalid-RangeType.ll
index 0f7116307c315..644e4e4348980 100644
--- a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-Invalid-RangeType.ll
+++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-Invalid-RangeType.ll
@@ -2,7 +2,7 @@
target triple = "dxil-unknown-shadermodel6.0-compute"
-; CHECK: error: Invalid Descriptor Range type: Invalid
+; CHECK: error: Invalid Descriptor Range type:Invalid
; CHECK-NOT: Root Signature Definitions
define void @main() #0 {
diff --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-Flags-Error.ll b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-Flags-Error.ll
index 65511160f230d..41e97701dcc20 100644
--- a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-Flags-Error.ll
+++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-Flags-Error.ll
@@ -2,7 +2,7 @@
target triple = "dxil-unknown-shadermodel6.0-compute"
-; CHECK: error: Invalid Root Signature Element: NOTRootFlags
+; CHECK: error: Invalid Root Signature Element:NOTRootFlags
; CHECK-NOT: Root Signature Definitions
diff --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootDescriptor-Invalid-RegisterKind.ll b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootDescriptor-Invalid-RegisterKind.ll
index 579528d8b5e13..9e63b06674ebc 100644
--- a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootDescriptor-Invalid-RegisterKind.ll
+++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootDescriptor-Invalid-RegisterKind.ll
@@ -3,7 +3,7 @@
target triple = "dxil-unknown-shadermodel6.0-compute"
-; CHECK: error: Invalid Root Signature Element: Invalid
+; CHECK: error: Invalid Root Signature Element:Invalid
; CHECK-NOT: Root Signature Definitions
define void @main() #0 {
entry:
diff --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-StaticSamplers-Invalid-MaxLod.ll b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-StaticSamplers-Invalid-MaxLod.ll
index 7e7d56eff309c..855e0c0cb6e51 100644
--- a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-StaticSamplers-Invalid-MaxLod.ll
+++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-StaticSamplers-Invalid-MaxLod.ll
@@ -3,7 +3,7 @@
target triple = "dxil-unknown-shadermodel6.0-compute"
-; CHECK: error: Invalid value for MaxLOD: 0
+; CHECK: error: Invalid value for MaxLOD: nan
; CHECK-NOT: Root Signature Definitions
define void @main() #0 {
diff --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-StaticSamplers-Invalid-MinLod.ll b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-StaticSamplers-Invalid-MinLod.ll
index d958f10d3c1af..812749b9ed824 100644
--- a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-StaticSamplers-Invalid-MinLod.ll
+++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-StaticSamplers-Invalid-MinLod.ll
@@ -3,7 +3,7 @@
target triple = "dxil-unknown-shadermodel6.0-compute"
-; CHECK: error: Invalid value for MinLOD: 0
+; CHECK: error: Invalid value for MinLOD: nan
; CHECK-NOT: Root Signature Definitions
define void @main() #0 {
diff --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-StaticSamplers-Invalid-MinLopBias.ll b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-StaticSamplers-Invalid-MinLopBias.ll
index 34b27eb40f5fb..6898aec6f2e49 100644
--- a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-StaticSamplers-Invalid-MinLopBias.ll
+++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-StaticSamplers-Invalid-MinLopBias.ll
@@ -3,7 +3,7 @@
target triple = "dxil-unknown-shadermodel6.0-compute"
-; CHECK: error: Invalid value for MipLODBias: 666
+; CHECK: error: Invalid value for MipLODBias: 6.660000e+02
; CHECK-NOT: Root Signature Definitions
define void @main() #0 {
>From 0c047d65c8dd707e799219d485a021ece6014e73 Mon Sep 17 00:00:00 2001
From: joaosaffran <joao.saffran at microsoft.com>
Date: Thu, 17 Jul 2025 01:30:14 +0000
Subject: [PATCH 08/10] clean
---
.../Frontend/HLSL/RootSignatureMetadata.h | 23 ------------------
.../Frontend/HLSL/RootSignatureMetadata.cpp | 24 ++++++++++++++++++-
llvm/lib/Target/DirectX/DXILRootSignature.cpp | 13 ++++++----
3 files changed, 32 insertions(+), 28 deletions(-)
diff --git a/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h b/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h
index 729ea22d3c8ab..b3705a2132021 100644
--- a/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h
+++ b/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h
@@ -26,29 +26,6 @@ class Metadata;
namespace hlsl {
namespace rootsig {
-inline std::optional<uint32_t> extractMdIntValue(MDNode *Node,
- unsigned int OpId) {
- if (auto *CI =
- mdconst::dyn_extract<ConstantInt>(Node->getOperand(OpId).get()))
- return CI->getZExtValue();
- return std::nullopt;
-}
-
-inline std::optional<float> extractMdFloatValue(MDNode *Node,
- unsigned int OpId) {
- if (auto *CI = mdconst::dyn_extract<ConstantFP>(Node->getOperand(OpId).get()))
- return CI->getValueAPF().convertToFloat();
- return std::nullopt;
-}
-
-inline std::optional<StringRef> extractMdStringValue(MDNode *Node,
- unsigned int OpId) {
- MDString *NodeText = dyn_cast<MDString>(Node->getOperand(OpId));
- if (NodeText == nullptr)
- return std::nullopt;
- return NodeText->getString();
-}
-
template <typename T>
class RootSignatureValidationError
: public ErrorInfo<RootSignatureValidationError<T>> {
diff --git a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
index 23c1815d438ad..41c23ecb692ea 100644
--- a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
+++ b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
@@ -24,9 +24,31 @@ namespace rootsig {
char GenericRSMetadataError::ID;
char InvalidRSMetadataFormat::ID;
char InvalidRSMetadataValue::ID;
-
template <typename T> char RootSignatureValidationError<T>::ID;
+inline std::optional<uint32_t> extractMdIntValue(MDNode *Node,
+ unsigned int OpId) {
+ if (auto *CI =
+ mdconst::dyn_extract<ConstantInt>(Node->getOperand(OpId).get()))
+ return CI->getZExtValue();
+ return std::nullopt;
+}
+
+inline std::optional<float> extractMdFloatValue(MDNode *Node,
+ unsigned int OpId) {
+ if (auto *CI = mdconst::dyn_extract<ConstantFP>(Node->getOperand(OpId).get()))
+ return CI->getValueAPF().convertToFloat();
+ return std::nullopt;
+}
+
+inline std::optional<StringRef> extractMdStringValue(MDNode *Node,
+ unsigned int OpId) {
+ MDString *NodeText = dyn_cast<MDString>(Node->getOperand(OpId));
+ if (NodeText == nullptr)
+ return std::nullopt;
+ return NodeText->getString();
+}
+
static const EnumEntry<dxil::ResourceClass> ResourceClassNames[] = {
{"CBV", dxil::ResourceClass::CBuffer},
{"SRV", dxil::ResourceClass::SRV},
diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.cpp b/llvm/lib/Target/DirectX/DXILRootSignature.cpp
index 924931ae0da5b..712fe8c958e5a 100644
--- a/llvm/lib/Target/DirectX/DXILRootSignature.cpp
+++ b/llvm/lib/Target/DirectX/DXILRootSignature.cpp
@@ -17,8 +17,6 @@
#include "llvm/Analysis/DXILMetadataAnalysis.h"
#include "llvm/BinaryFormat/DXContainer.h"
#include "llvm/Frontend/HLSL/RootSignatureMetadata.h"
-#include "llvm/Frontend/HLSL/RootSignatureMetadata.h"
-#include "llvm/Frontend/HLSL/RootSignatureValidations.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/Function.h"
@@ -35,6 +33,14 @@
using namespace llvm;
using namespace llvm::dxil;
+static std::optional<uint32_t> extractMdIntValue(MDNode *Node,
+ unsigned int OpId) {
+ if (auto *CI =
+ mdconst::dyn_extract<ConstantInt>(Node->getOperand(OpId).get()))
+ return CI->getZExtValue();
+ return std::nullopt;
+}
+
static bool reportError(LLVMContext *Ctx, Twine Message,
DiagnosticSeverity Severity = DS_Error) {
Ctx->diagnose(DiagnosticInfoGeneric(Message, Severity));
@@ -105,8 +111,7 @@ analyzeModule(Module &M) {
continue;
}
uint32_t Version = 1;
- if (std::optional<uint32_t> V =
- llvm::hlsl::rootsig::extractMdIntValue(RSDefNode, 2))
+ if (std::optional<uint32_t> V = extractMdIntValue(RSDefNode, 2))
Version = *V;
else {
reportError(Ctx, "Invalid RSDefNode value, expected constant int");
>From 33d14186510cb8f1c04e2a9479fd92212ff565cb Mon Sep 17 00:00:00 2001
From: joaosaffran <joao.saffran at microsoft.com>
Date: Tue, 22 Jul 2025 01:42:13 +0000
Subject: [PATCH 09/10] address comments
---
llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h | 1 -
1 file changed, 1 deletion(-)
diff --git a/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h b/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h
index cd5966db42b5f..6fa51eded52f0 100644
--- a/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h
+++ b/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h
@@ -66,7 +66,6 @@ class MetadataParser {
public:
MetadataParser(MDNode *Root) : Root(Root) {}
- /// Iterates through root signature and converts them into MapT
LLVM_ABI bool ParseRootSignature(LLVMContext *Ctx,
mcdxbc::RootSignatureDesc &RSD);
>From f2845f2b19b6f369a846f25976b8d5ef38f53382 Mon Sep 17 00:00:00 2001
From: joaosaffran <joao.saffran at microsoft.com>
Date: Tue, 22 Jul 2025 06:26:24 +0000
Subject: [PATCH 10/10] address comments
---
llvm/lib/Target/DirectX/DXILRootSignature.cpp | 8 +++-----
1 file changed, 3 insertions(+), 5 deletions(-)
diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.cpp b/llvm/lib/Target/DirectX/DXILRootSignature.cpp
index 712fe8c958e5a..97634c1f3cb9c 100644
--- a/llvm/lib/Target/DirectX/DXILRootSignature.cpp
+++ b/llvm/lib/Target/DirectX/DXILRootSignature.cpp
@@ -110,17 +110,15 @@ analyzeModule(Module &M) {
reportError(Ctx, "Root Element is not a metadata node.");
continue;
}
- uint32_t Version = 1;
- if (std::optional<uint32_t> V = extractMdIntValue(RSDefNode, 2))
- Version = *V;
- else {
+ std::optional<uint32_t> V = extractMdIntValue(RSDefNode, 2);
+ if (!V.has_value()) {
reportError(Ctx, "Invalid RSDefNode value, expected constant int");
continue;
}
llvm::hlsl::rootsig::MetadataParser MDParser(RootElementListNode);
llvm::Expected<mcdxbc::RootSignatureDesc> RSDOrErr =
- MDParser.ParseRootSignature(Version);
+ MDParser.ParseRootSignature(V.value());
if (auto Err = RSDOrErr.takeError()) {
reportError(Ctx, toString(std::move(Err)));
More information about the llvm-commits
mailing list