[llvm-branch-commits] [llvm] [DirectX] Improve error accumulation in root signature parsing (PR #144465)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Tue Jun 17 11:15:32 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-directx
Author: None (joaosaffran)
<details>
<summary>Changes</summary>
This patch enhances error handling in the DirectX backend's root signature
parsing, specifically in DXILRootSignature.cpp. The changes include:
1. Modify error handling to accumulate errors:
   - Replace early returns with error accumulation using `HasError`
   - Allow validation to continue after encountering an invalid type
   - Maintain original error reporting functionality while collecting multiple errors
2. Fix root flag parsing:
   - Use a boolean accumulator for multiple validation errors
   - Improve invalid type reporting for root flag nodes
   - Maintain consistency with existing error reporting patterns
Before this change, the parser would stop at the first error encountered. Now it
continues validation, collecting all errors before returning. This provides a better
developer experience by showing all issues that need to be fixed at once.
Example of changes:
```cpp
bool HasError = false;
if (std::optional<uint32_t> Val = extractMdIntValue(RootFlagNode, 1))
  RSD.Flags = *Val;
else
  HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RootFlagNode",
                                                 RootFlagNode, 1) ||
             HasError;
return HasError;
```
Testing:
- All existing DirectX backend tests pass
- Verified error accumulation with multiple validation failures
- Root signature parsing continues to work as expected
---
Patch is 25.66 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/144465.diff
2 Files Affected:
- (modified) llvm/lib/Target/DirectX/DXILRootSignature.cpp (+171-110)
- (added) llvm/test/CodeGen/DirectX/ContainerData/RootSignature-Error-Accumulation.ll (+23)
``````````diff
diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.cpp b/llvm/lib/Target/DirectX/DXILRootSignature.cpp
index 57d5ee8ac467c..eea46e714b756 100644
--- a/llvm/lib/Target/DirectX/DXILRootSignature.cpp
+++ b/llvm/lib/Target/DirectX/DXILRootSignature.cpp
@@ -141,14 +141,15 @@ static bool parseRootFlags(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
if (RootFlagNode->getNumOperands() != 2)
return reportError(Ctx, "Invalid format for RootFlag Element");
-
+ bool HasError = false;
if (std::optional<uint32_t> Val = extractMdIntValue(RootFlagNode, 1))
RSD.Flags = *Val;
else
- return reportInvalidTypeError<ConstantInt>(Ctx, "RootFlagNode",
- RootFlagNode, 1);
+ HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RootFlagNode",
+ RootFlagNode, 1) ||
+ HasError;
- return false;
+ return HasError;
}
static bool parseRootConstants(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
@@ -157,6 +158,7 @@ static bool parseRootConstants(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
if (RootConstantNode->getNumOperands() != 5)
return reportError(Ctx, "Invalid format for RootConstants Element");
+ bool HasError = false;
dxbc::RTS0::v1::RootParameterHeader Header;
// The parameter offset doesn't matter here - we recalculate it during
// serialization Header.ParameterOffset = 0;
@@ -166,31 +168,35 @@ static bool parseRootConstants(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 1))
Header.ShaderVisibility = *Val;
else
- return reportInvalidTypeError<ConstantInt>(Ctx, "RootConstantNode",
- RootConstantNode, 1);
+ HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RootConstantNode",
+ RootConstantNode, 1) ||
+ HasError;
dxbc::RTS0::v1::RootConstants Constants;
if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 2))
Constants.ShaderRegister = *Val;
else
- return reportInvalidTypeError<ConstantInt>(Ctx, "RootConstantNode",
- RootConstantNode, 2);
+ HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RootConstantNode",
+ RootConstantNode, 2) ||
+ HasError;
if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 3))
Constants.RegisterSpace = *Val;
else
- return reportInvalidTypeError<ConstantInt>(Ctx, "RootConstantNode",
- RootConstantNode, 3);
+ HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RootConstantNode",
+ RootConstantNode, 3) ||
+ HasError;
if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 4))
Constants.Num32BitValues = *Val;
else
- return reportInvalidTypeError<ConstantInt>(Ctx, "RootConstantNode",
- RootConstantNode, 4);
-
- RSD.ParametersContainer.addParameter(Header, Constants);
+ HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RootConstantNode",
+ RootConstantNode, 4) ||
+ HasError;
+ if (!HasError)
+ RSD.ParametersContainer.addParameter(Header, Constants);
- return false;
+ return HasError;
}
static bool parseRootDescriptors(LLVMContext *Ctx,
@@ -205,6 +211,7 @@ static bool parseRootDescriptors(LLVMContext *Ctx,
if (RootDescriptorNode->getNumOperands() != 5)
return reportError(Ctx, "Invalid format for Root Descriptor Element");
+ bool HasError = false;
dxbc::RTS0::v1::RootParameterHeader Header;
switch (ElementKind) {
case RootSignatureElementKind::SRV:
@@ -224,36 +231,41 @@ static bool parseRootDescriptors(LLVMContext *Ctx,
if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 1))
Header.ShaderVisibility = *Val;
else
- return reportInvalidTypeError<ConstantInt>(Ctx, "RootDescriptorNode",
- RootDescriptorNode, 1);
+ HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RootDescriptorNode",
+ RootDescriptorNode, 1) ||
+ HasError;
dxbc::RTS0::v2::RootDescriptor Descriptor;
if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 2))
Descriptor.ShaderRegister = *Val;
else
- return reportInvalidTypeError<ConstantInt>(Ctx, "RootDescriptorNode",
- RootDescriptorNode, 2);
+ HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RootDescriptorNode",
+ RootDescriptorNode, 2) ||
+ HasError;
if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 3))
Descriptor.RegisterSpace = *Val;
else
- return reportInvalidTypeError<ConstantInt>(Ctx, "RootDescriptorNode",
- RootDescriptorNode, 3);
+ HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RootDescriptorNode",
+ RootDescriptorNode, 3) ||
+ HasError;
if (RSD.Version == 1) {
- RSD.ParametersContainer.addParameter(Header, Descriptor);
- return false;
+ if (!HasError)
+ RSD.ParametersContainer.addParameter(Header, Descriptor);
+ return HasError;
}
assert(RSD.Version > 1);
if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 4))
Descriptor.Flags = *Val;
else
- return reportInvalidTypeError<ConstantInt>(Ctx, "RootDescriptorNode",
- RootDescriptorNode, 4);
-
- RSD.ParametersContainer.addParameter(Header, Descriptor);
- return false;
+ HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RootDescriptorNode",
+ RootDescriptorNode, 4) ||
+ HasError;
+ if (!HasError)
+ RSD.ParametersContainer.addParameter(Header, Descriptor);
+ return HasError;
}
static bool parseDescriptorRange(LLVMContext *Ctx,
@@ -264,14 +276,16 @@ static bool parseDescriptorRange(LLVMContext *Ctx,
if (RangeDescriptorNode->getNumOperands() != 6)
return reportError(Ctx, "Invalid format for Descriptor Range");
+ bool HasError = false;
dxbc::RTS0::v2::DescriptorRange Range;
std::optional<StringRef> ElementText =
extractMdStringValue(RangeDescriptorNode, 0);
if (!ElementText.has_value())
- return reportInvalidTypeError<MDString>(Ctx, "RangeDescriptorNode",
- RangeDescriptorNode, 0);
+ HasError = reportInvalidTypeError<MDString>(Ctx, "RangeDescriptorNode",
+ RangeDescriptorNode, 0) ||
+ HasError;
Range.RangeType =
StringSwitch<uint32_t>(*ElementText)
@@ -283,40 +297,47 @@ static bool parseDescriptorRange(LLVMContext *Ctx,
.Default(-1u);
if (Range.RangeType == -1u)
- return reportError(Ctx, "Invalid Descriptor Range type: " + *ElementText);
+ HasError =
+ reportError(Ctx, "Invalid Descriptor Range type: " + *ElementText) ||
+ HasError;
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 1))
Range.NumDescriptors = *Val;
else
- return reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode",
- RangeDescriptorNode, 1);
+ HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode",
+ RangeDescriptorNode, 1) ||
+ HasError;
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 2))
Range.BaseShaderRegister = *Val;
else
- return reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode",
- RangeDescriptorNode, 2);
+ HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode",
+ RangeDescriptorNode, 2) ||
+ HasError;
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 3))
Range.RegisterSpace = *Val;
else
- return reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode",
- RangeDescriptorNode, 3);
+ HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode",
+ RangeDescriptorNode, 3) ||
+ HasError;
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 4))
Range.OffsetInDescriptorsFromTableStart = *Val;
else
- return reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode",
- RangeDescriptorNode, 4);
+ HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode",
+ RangeDescriptorNode, 4) ||
+ HasError;
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 5))
Range.Flags = *Val;
else
- return reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode",
- RangeDescriptorNode, 5);
-
- Table.Ranges.push_back(Range);
- return false;
+ HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode",
+ RangeDescriptorNode, 5) ||
+ HasError;
+ if (!HasError)
+ Table.Ranges.push_back(Range);
+ return HasError;
}
static bool parseDescriptorTable(LLVMContext *Ctx,
@@ -325,13 +346,14 @@ static bool parseDescriptorTable(LLVMContext *Ctx,
const unsigned int NumOperands = DescriptorTableNode->getNumOperands();
if (NumOperands < 2)
return reportError(Ctx, "Invalid format for Descriptor Table");
-
+ bool HasError = false;
dxbc::RTS0::v1::RootParameterHeader Header;
if (std::optional<uint32_t> Val = extractMdIntValue(DescriptorTableNode, 1))
Header.ShaderVisibility = *Val;
else
- return reportInvalidTypeError<ConstantInt>(Ctx, "DescriptorTableNode",
- DescriptorTableNode, 1);
+ HasError = reportInvalidTypeError<ConstantInt>(Ctx, "DescriptorTableNode",
+ DescriptorTableNode, 1) ||
+ HasError;
mcdxbc::DescriptorTable Table;
Header.ParameterType =
@@ -340,15 +362,16 @@ static bool parseDescriptorTable(LLVMContext *Ctx,
for (unsigned int I = 2; I < NumOperands; I++) {
MDNode *Element = dyn_cast<MDNode>(DescriptorTableNode->getOperand(I));
if (Element == nullptr)
- return reportInvalidTypeError<MDNode>(Ctx, "DescriptorTableNode",
- DescriptorTableNode, I);
+ HasError = reportInvalidTypeError<MDNode>(Ctx, "DescriptorTableNode",
+ DescriptorTableNode, I) ||
+ HasError;
if (parseDescriptorRange(Ctx, RSD, Table, Element))
- return true;
+ HasError = true || HasError;
}
-
- RSD.ParametersContainer.addParameter(Header, Table);
- return false;
+ if (!HasError)
+ RSD.ParametersContainer.addParameter(Header, Table);
+ return HasError;
}
static bool parseStaticSampler(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
@@ -356,87 +379,101 @@ static bool parseStaticSampler(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
if (StaticSamplerNode->getNumOperands() != 14)
return reportError(Ctx, "Invalid format for Static Sampler");
+ bool HasError = false;
dxbc::RTS0::v1::StaticSampler Sampler;
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 1))
Sampler.Filter = *Val;
else
- return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
- StaticSamplerNode, 1);
+ HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+ StaticSamplerNode, 1) ||
+ HasError;
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 2))
Sampler.AddressU = *Val;
else
- return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
- StaticSamplerNode, 2);
+ HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+ StaticSamplerNode, 2) ||
+ HasError;
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 3))
Sampler.AddressV = *Val;
else
- return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
- StaticSamplerNode, 3);
+ HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+ StaticSamplerNode, 3) ||
+ HasError;
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 4))
Sampler.AddressW = *Val;
else
- return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
- StaticSamplerNode, 4);
+ HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+ StaticSamplerNode, 4) ||
+ HasError;
if (std::optional<APFloat> Val = extractMdFloatValue(StaticSamplerNode, 5))
Sampler.MipLODBias = Val->convertToFloat();
else
- return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
- StaticSamplerNode, 5);
+ HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+ StaticSamplerNode, 5) ||
+ HasError;
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 6))
Sampler.MaxAnisotropy = *Val;
else
- return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
- StaticSamplerNode, 6);
+ HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+ StaticSamplerNode, 6) ||
+ HasError;
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 7))
Sampler.ComparisonFunc = *Val;
else
- return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
- StaticSamplerNode, 7);
+ HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+ StaticSamplerNode, 7) ||
+ HasError;
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 8))
Sampler.BorderColor = *Val;
else
- return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
- StaticSamplerNode, 8);
+ HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+ StaticSamplerNode, 8) ||
+ HasError;
if (std::optional<APFloat> Val = extractMdFloatValue(StaticSamplerNode, 9))
Sampler.MinLOD = Val->convertToFloat();
else
- return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
- StaticSamplerNode, 9);
+ HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+ StaticSamplerNode, 9) ||
+ HasError;
if (std::optional<APFloat> Val = extractMdFloatValue(StaticSamplerNode, 10))
Sampler.MaxLOD = Val->convertToFloat();
else
- return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
- StaticSamplerNode, 10);
+ HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+ StaticSamplerNode, 10) ||
+ HasError;
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 11))
Sampler.ShaderRegister = *Val;
else
- return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
- StaticSamplerNode, 11);
+ HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+ StaticSamplerNode, 11) ||
+ HasError;
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 12))
Sampler.RegisterSpace = *Val;
else
- return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
- StaticSamplerNode, 12);
+ HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+ StaticSamplerNode, 12) ||
+ HasError;
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 13))
Sampler.ShaderVisibility = *Val;
else
- return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
- StaticSamplerNode, 13);
-
- RSD.StaticSamplers.push_back(Sampler);
- return false;
+ HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+ StaticSamplerNode, 13) ||
+ HasError;
+ if (!HasError)
+ RSD.StaticSamplers.push_back(Sampler);
+ return HasError;
}
static bool parseRootSignatureElement(LLVMContext *Ctx,
@@ -488,7 +525,7 @@ static bool parse(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
if (Element == nullptr)
return reportError(Ctx, "Missing Root Element Metadata Node.");
- HasError = HasError || parseRootSignatureElement(Ctx, RSD, Element);
+ HasError = parseRootSignatureElement(Ctx, RSD, Element) || HasError;
}
return HasError;
@@ -699,19 +736,20 @@ static bool verifyBorderColor(uint32_t BorderColor) {
static bool verifyLOD(float LOD) { return !std::isnan(LOD); }
static bool validate(LLVMContext *Ctx, const mcdxbc::RootSignatureDesc &RSD) {
-
+ bool HasError = false;
if (!verifyVersion(RSD.Version)) {
- return reportValueError(Ctx, "Version", RSD.Version);
+ HasError = reportValueError(Ctx, "Version", RSD.Version) || HasError;
}
if (!verifyRootFlag(RSD.Flags)) {
- return reportValueError(Ctx, "RootFlags", RSD.Flags);
+ HasError = reportValueError(Ctx, "RootFlags", RSD.Flags) || HasError;
}
for (const mcdxbc::RootParameterInfo &Info : RSD.ParametersContainer) {
if (!dxbc::isValidShaderVisibility(Info.Header.ShaderVisibility))
- return reportValueError(Ctx, "ShaderVisibility",
- Info.Header.ShaderVisibility);
+ HasError = reportValueError(Ctx, "ShaderVisibility",
+ Info.Header.ShaderVisibility) ||
+ HasError;
assert(dxbc::isValidParameterType(Info.Header.ParameterType) &&
"Invalid value for ParameterType");
@@ -724,15 +762,20 @@ static bool validate(LLVMContext *Ctx, const mcdxbc::RootSignatureDesc &RSD) {
const dxbc::RTS0::v2::RootDescriptor &Descriptor =
RSD.ParametersContainer.getRootDescriptor(Info.Location);
if (!verifyRegisterValue(Descriptor.ShaderRegister))
- return reportValueError(Ctx, "ShaderRegister",
- Descriptor.ShaderRegister);
+ HasError = reportValueError(Ctx, "ShaderRegister",
+ Descriptor.ShaderRegister) ||
+ HasError;
if (!verifyRegisterSpace(Descriptor.RegisterSpace))
- return reportValueError(Ctx, "RegisterSpace", Descriptor.RegisterSpace);
+ HasError =
+ reportValueError(Ctx, "RegisterSpace", Descriptor.RegisterSpace) ||
+ HasError;
if (RSD.Version > 1)...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/144465
More information about the llvm-branch-commits
mailing list