[llvm-branch-commits] [llvm] [DirectX] Improve error accumulation in root signature parsing (PR #144465)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Tue Jun 17 11:15:32 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-directx

Author: None (joaosaffran)

<details>
<summary>Changes</summary>

This patch improves error handling in the DirectX backend's root signature 
parsing and validation, in DXILRootSignature.cpp. The changes include:

1. Modify error handling to accumulate errors:
   - Replace early returns with error accumulation using HasError
   - Allow validation to continue after encountering an invalid type
   - Maintain original error reporting functionality while collecting multiple errors

2. Fix root flag parsing:
   - Use boolean accumulator for multiple validation errors
   - Improve invalid type reporting for root flag nodes
   - Maintain consistency with existing error reporting patterns

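Before this change, the parser stopped at the first error it encountered. Now,
apart from structural errors such as a wrong operand count (which still return
immediately), parsing and validation continue past a failure and report every
error found before returning. This gives developers a better experience by
showing all the issues that need fixing at once.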

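Example of the change (from `parseRootFlags`):
```cpp
bool HasError = false;
if (std::optional<uint32_t> Val = extractMdIntValue(RootFlagNode, 1))
  RSD.Flags = *Val;
else
  HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RootFlagNode",
                                                 RootFlagNode, 1) ||
             HasError;
return HasError;
```

The report call is placed on the left of `||` so that it is always evaluated.
Writing `HasError || report(...)` would short-circuit and suppress every
diagnostic after the first error.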

Testing:
- All existing DirectX backend tests pass
- Verified error accumulation with multiple validation failures
- Root signature parsing continues to work as expected

---

Patch is 25.66 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/144465.diff


2 Files Affected:

- (modified) llvm/lib/Target/DirectX/DXILRootSignature.cpp (+171-110) 
- (added) llvm/test/CodeGen/DirectX/ContainerData/RootSignature-Error-Accumulation.ll (+23) 


``````````diff
diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.cpp b/llvm/lib/Target/DirectX/DXILRootSignature.cpp
index 57d5ee8ac467c..eea46e714b756 100644
--- a/llvm/lib/Target/DirectX/DXILRootSignature.cpp
+++ b/llvm/lib/Target/DirectX/DXILRootSignature.cpp
@@ -141,14 +141,15 @@ static bool parseRootFlags(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
 
   if (RootFlagNode->getNumOperands() != 2)
     return reportError(Ctx, "Invalid format for RootFlag Element");
-
+  bool HasError = false;
   if (std::optional<uint32_t> Val = extractMdIntValue(RootFlagNode, 1))
     RSD.Flags = *Val;
   else
-    return reportInvalidTypeError<ConstantInt>(Ctx, "RootFlagNode",
-                                               RootFlagNode, 1);
+    HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RootFlagNode",
+                                                   RootFlagNode, 1) ||
+               HasError;
 
-  return false;
+  return HasError;
 }
 
 static bool parseRootConstants(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
@@ -157,6 +158,7 @@ static bool parseRootConstants(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
   if (RootConstantNode->getNumOperands() != 5)
     return reportError(Ctx, "Invalid format for RootConstants Element");
 
+  bool HasError = false;
   dxbc::RTS0::v1::RootParameterHeader Header;
   // The parameter offset doesn't matter here - we recalculate it during
   // serialization  Header.ParameterOffset = 0;
@@ -166,31 +168,35 @@ static bool parseRootConstants(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
   if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 1))
     Header.ShaderVisibility = *Val;
   else
-    return reportInvalidTypeError<ConstantInt>(Ctx, "RootConstantNode",
-                                               RootConstantNode, 1);
+    HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RootConstantNode",
+                                                   RootConstantNode, 1) ||
+               HasError;
 
   dxbc::RTS0::v1::RootConstants Constants;
   if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 2))
     Constants.ShaderRegister = *Val;
   else
-    return reportInvalidTypeError<ConstantInt>(Ctx, "RootConstantNode",
-                                               RootConstantNode, 2);
+    HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RootConstantNode",
+                                                   RootConstantNode, 2) ||
+               HasError;
 
   if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 3))
     Constants.RegisterSpace = *Val;
   else
-    return reportInvalidTypeError<ConstantInt>(Ctx, "RootConstantNode",
-                                               RootConstantNode, 3);
+    HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RootConstantNode",
+                                                   RootConstantNode, 3) ||
+               HasError;
 
   if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 4))
     Constants.Num32BitValues = *Val;
   else
-    return reportInvalidTypeError<ConstantInt>(Ctx, "RootConstantNode",
-                                               RootConstantNode, 4);
-
-  RSD.ParametersContainer.addParameter(Header, Constants);
+    HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RootConstantNode",
+                                                   RootConstantNode, 4) ||
+               HasError;
+  if (!HasError)
+    RSD.ParametersContainer.addParameter(Header, Constants);
 
-  return false;
+  return HasError;
 }
 
 static bool parseRootDescriptors(LLVMContext *Ctx,
@@ -205,6 +211,7 @@ static bool parseRootDescriptors(LLVMContext *Ctx,
   if (RootDescriptorNode->getNumOperands() != 5)
     return reportError(Ctx, "Invalid format for Root Descriptor Element");
 
+  bool HasError = false;
   dxbc::RTS0::v1::RootParameterHeader Header;
   switch (ElementKind) {
   case RootSignatureElementKind::SRV:
@@ -224,36 +231,41 @@ static bool parseRootDescriptors(LLVMContext *Ctx,
   if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 1))
     Header.ShaderVisibility = *Val;
   else
-    return reportInvalidTypeError<ConstantInt>(Ctx, "RootDescriptorNode",
-                                               RootDescriptorNode, 1);
+    HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RootDescriptorNode",
+                                                   RootDescriptorNode, 1) ||
+               HasError;
 
   dxbc::RTS0::v2::RootDescriptor Descriptor;
   if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 2))
     Descriptor.ShaderRegister = *Val;
   else
-    return reportInvalidTypeError<ConstantInt>(Ctx, "RootDescriptorNode",
-                                               RootDescriptorNode, 2);
+    HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RootDescriptorNode",
+                                                   RootDescriptorNode, 2) ||
+               HasError;
 
   if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 3))
     Descriptor.RegisterSpace = *Val;
   else
-    return reportInvalidTypeError<ConstantInt>(Ctx, "RootDescriptorNode",
-                                               RootDescriptorNode, 3);
+    HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RootDescriptorNode",
+                                                   RootDescriptorNode, 3) ||
+               HasError;
 
   if (RSD.Version == 1) {
-    RSD.ParametersContainer.addParameter(Header, Descriptor);
-    return false;
+    if (!HasError)
+      RSD.ParametersContainer.addParameter(Header, Descriptor);
+    return HasError;
   }
   assert(RSD.Version > 1);
 
   if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 4))
     Descriptor.Flags = *Val;
   else
-    return reportInvalidTypeError<ConstantInt>(Ctx, "RootDescriptorNode",
-                                               RootDescriptorNode, 4);
-
-  RSD.ParametersContainer.addParameter(Header, Descriptor);
-  return false;
+    HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RootDescriptorNode",
+                                                   RootDescriptorNode, 4) ||
+               HasError;
+  if (!HasError)
+    RSD.ParametersContainer.addParameter(Header, Descriptor);
+  return HasError;
 }
 
 static bool parseDescriptorRange(LLVMContext *Ctx,
@@ -264,14 +276,16 @@ static bool parseDescriptorRange(LLVMContext *Ctx,
   if (RangeDescriptorNode->getNumOperands() != 6)
     return reportError(Ctx, "Invalid format for Descriptor Range");
 
+  bool HasError = false;
   dxbc::RTS0::v2::DescriptorRange Range;
 
   std::optional<StringRef> ElementText =
       extractMdStringValue(RangeDescriptorNode, 0);
 
   if (!ElementText.has_value())
-    return reportInvalidTypeError<MDString>(Ctx, "RangeDescriptorNode",
-                                            RangeDescriptorNode, 0);
+    HasError = reportInvalidTypeError<MDString>(Ctx, "RangeDescriptorNode",
+                                                RangeDescriptorNode, 0) ||
+               HasError;
 
   Range.RangeType =
       StringSwitch<uint32_t>(*ElementText)
@@ -283,40 +297,47 @@ static bool parseDescriptorRange(LLVMContext *Ctx,
           .Default(-1u);
 
   if (Range.RangeType == -1u)
-    return reportError(Ctx, "Invalid Descriptor Range type: " + *ElementText);
+    HasError =
+        reportError(Ctx, "Invalid Descriptor Range type: " + *ElementText) ||
+        HasError;
 
   if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 1))
     Range.NumDescriptors = *Val;
   else
-    return reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode",
-                                               RangeDescriptorNode, 1);
+    HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode",
+                                                   RangeDescriptorNode, 1) ||
+               HasError;
 
   if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 2))
     Range.BaseShaderRegister = *Val;
   else
-    return reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode",
-                                               RangeDescriptorNode, 2);
+    HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode",
+                                                   RangeDescriptorNode, 2) ||
+               HasError;
 
   if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 3))
     Range.RegisterSpace = *Val;
   else
-    return reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode",
-                                               RangeDescriptorNode, 3);
+    HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode",
+                                                   RangeDescriptorNode, 3) ||
+               HasError;
 
   if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 4))
     Range.OffsetInDescriptorsFromTableStart = *Val;
   else
-    return reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode",
-                                               RangeDescriptorNode, 4);
+    HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode",
+                                                   RangeDescriptorNode, 4) ||
+               HasError;
 
   if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 5))
     Range.Flags = *Val;
   else
-    return reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode",
-                                               RangeDescriptorNode, 5);
-
-  Table.Ranges.push_back(Range);
-  return false;
+    HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode",
+                                                   RangeDescriptorNode, 5) ||
+               HasError;
+  if (!HasError)
+    Table.Ranges.push_back(Range);
+  return HasError;
 }
 
 static bool parseDescriptorTable(LLVMContext *Ctx,
@@ -325,13 +346,14 @@ static bool parseDescriptorTable(LLVMContext *Ctx,
   const unsigned int NumOperands = DescriptorTableNode->getNumOperands();
   if (NumOperands < 2)
     return reportError(Ctx, "Invalid format for Descriptor Table");
-
+  bool HasError = false;
   dxbc::RTS0::v1::RootParameterHeader Header;
   if (std::optional<uint32_t> Val = extractMdIntValue(DescriptorTableNode, 1))
     Header.ShaderVisibility = *Val;
   else
-    return reportInvalidTypeError<ConstantInt>(Ctx, "DescriptorTableNode",
-                                               DescriptorTableNode, 1);
+    HasError = reportInvalidTypeError<ConstantInt>(Ctx, "DescriptorTableNode",
+                                                   DescriptorTableNode, 1) ||
+               HasError;
 
   mcdxbc::DescriptorTable Table;
   Header.ParameterType =
@@ -340,15 +362,16 @@ static bool parseDescriptorTable(LLVMContext *Ctx,
   for (unsigned int I = 2; I < NumOperands; I++) {
     MDNode *Element = dyn_cast<MDNode>(DescriptorTableNode->getOperand(I));
     if (Element == nullptr)
-      return reportInvalidTypeError<MDNode>(Ctx, "DescriptorTableNode",
-                                            DescriptorTableNode, I);
+      HasError = reportInvalidTypeError<MDNode>(Ctx, "DescriptorTableNode",
+                                                DescriptorTableNode, I) ||
+                 HasError;
 
     if (parseDescriptorRange(Ctx, RSD, Table, Element))
-      return true;
+      HasError = true || HasError;
   }
-
-  RSD.ParametersContainer.addParameter(Header, Table);
-  return false;
+  if (!HasError)
+    RSD.ParametersContainer.addParameter(Header, Table);
+  return HasError;
 }
 
 static bool parseStaticSampler(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
@@ -356,87 +379,101 @@ static bool parseStaticSampler(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
   if (StaticSamplerNode->getNumOperands() != 14)
     return reportError(Ctx, "Invalid format for Static Sampler");
 
+  bool HasError = false;
   dxbc::RTS0::v1::StaticSampler Sampler;
   if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 1))
     Sampler.Filter = *Val;
   else
-    return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
-                                               StaticSamplerNode, 1);
+    HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+                                                   StaticSamplerNode, 1) ||
+               HasError;
 
   if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 2))
     Sampler.AddressU = *Val;
   else
-    return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
-                                               StaticSamplerNode, 2);
+    HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+                                                   StaticSamplerNode, 2) ||
+               HasError;
 
   if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 3))
     Sampler.AddressV = *Val;
   else
-    return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
-                                               StaticSamplerNode, 3);
+    HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+                                                   StaticSamplerNode, 3) ||
+               HasError;
 
   if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 4))
     Sampler.AddressW = *Val;
   else
-    return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
-                                               StaticSamplerNode, 4);
+    HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+                                                   StaticSamplerNode, 4) ||
+               HasError;
 
   if (std::optional<APFloat> Val = extractMdFloatValue(StaticSamplerNode, 5))
     Sampler.MipLODBias = Val->convertToFloat();
   else
-    return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
-                                               StaticSamplerNode, 5);
+    HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+                                                   StaticSamplerNode, 5) ||
+               HasError;
 
   if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 6))
     Sampler.MaxAnisotropy = *Val;
   else
-    return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
-                                               StaticSamplerNode, 6);
+    HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+                                                   StaticSamplerNode, 6) ||
+               HasError;
 
   if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 7))
     Sampler.ComparisonFunc = *Val;
   else
-    return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
-                                               StaticSamplerNode, 7);
+    HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+                                                   StaticSamplerNode, 7) ||
+               HasError;
 
   if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 8))
     Sampler.BorderColor = *Val;
   else
-    return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
-                                               StaticSamplerNode, 8);
+    HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+                                                   StaticSamplerNode, 8) ||
+               HasError;
 
   if (std::optional<APFloat> Val = extractMdFloatValue(StaticSamplerNode, 9))
     Sampler.MinLOD = Val->convertToFloat();
   else
-    return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
-                                               StaticSamplerNode, 9);
+    HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+                                                   StaticSamplerNode, 9) ||
+               HasError;
 
   if (std::optional<APFloat> Val = extractMdFloatValue(StaticSamplerNode, 10))
     Sampler.MaxLOD = Val->convertToFloat();
   else
-    return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
-                                               StaticSamplerNode, 10);
+    HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+                                                   StaticSamplerNode, 10) ||
+               HasError;
 
   if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 11))
     Sampler.ShaderRegister = *Val;
   else
-    return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
-                                               StaticSamplerNode, 11);
+    HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+                                                   StaticSamplerNode, 11) ||
+               HasError;
 
   if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 12))
     Sampler.RegisterSpace = *Val;
   else
-    return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
-                                               StaticSamplerNode, 12);
+    HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+                                                   StaticSamplerNode, 12) ||
+               HasError;
 
   if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 13))
     Sampler.ShaderVisibility = *Val;
   else
-    return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
-                                               StaticSamplerNode, 13);
-
-  RSD.StaticSamplers.push_back(Sampler);
-  return false;
+    HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+                                                   StaticSamplerNode, 13) ||
+               HasError;
+  if (!HasError)
+    RSD.StaticSamplers.push_back(Sampler);
+  return HasError;
 }
 
 static bool parseRootSignatureElement(LLVMContext *Ctx,
@@ -488,7 +525,7 @@ static bool parse(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
     if (Element == nullptr)
       return reportError(Ctx, "Missing Root Element Metadata Node.");
 
-    HasError = HasError || parseRootSignatureElement(Ctx, RSD, Element);
+    HasError = parseRootSignatureElement(Ctx, RSD, Element) || HasError;
   }
 
   return HasError;
@@ -699,19 +736,20 @@ static bool verifyBorderColor(uint32_t BorderColor) {
 static bool verifyLOD(float LOD) { return !std::isnan(LOD); }
 
 static bool validate(LLVMContext *Ctx, const mcdxbc::RootSignatureDesc &RSD) {
-
+  bool HasError = false;
   if (!verifyVersion(RSD.Version)) {
-    return reportValueError(Ctx, "Version", RSD.Version);
+    HasError = reportValueError(Ctx, "Version", RSD.Version) || HasError;
   }
 
   if (!verifyRootFlag(RSD.Flags)) {
-    return reportValueError(Ctx, "RootFlags", RSD.Flags);
+    HasError = reportValueError(Ctx, "RootFlags", RSD.Flags) || HasError;
   }
 
   for (const mcdxbc::RootParameterInfo &Info : RSD.ParametersContainer) {
     if (!dxbc::isValidShaderVisibility(Info.Header.ShaderVisibility))
-      return reportValueError(Ctx, "ShaderVisibility",
-                              Info.Header.ShaderVisibility);
+      HasError = reportValueError(Ctx, "ShaderVisibility",
+                                  Info.Header.ShaderVisibility) ||
+                 HasError;
 
     assert(dxbc::isValidParameterType(Info.Header.ParameterType) &&
            "Invalid value for ParameterType");
@@ -724,15 +762,20 @@ static bool validate(LLVMContext *Ctx, const mcdxbc::RootSignatureDesc &RSD) {
       const dxbc::RTS0::v2::RootDescriptor &Descriptor =
           RSD.ParametersContainer.getRootDescriptor(Info.Location);
       if (!verifyRegisterValue(Descriptor.ShaderRegister))
-        return reportValueError(Ctx, "ShaderRegister",
-                                Descriptor.ShaderRegister);
+        HasError = reportValueError(Ctx, "ShaderRegister",
+                                    Descriptor.ShaderRegister) ||
+                   HasError;
 
       if (!verifyRegisterSpace(Descriptor.RegisterSpace))
-        return reportValueError(Ctx, "RegisterSpace", Descriptor.RegisterSpace);
+        HasError =
+            reportValueError(Ctx, "RegisterSpace", Descriptor.RegisterSpace) ||
+            HasError;
 
       if (RSD.Version > 1)...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/144465


More information about the llvm-branch-commits mailing list