[llvm-branch-commits] [llvm] [DirectX] Improve error handling and validation in root signature parsing (PR #144577)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Tue Jun 17 11:11:39 PDT 2025
https://github.com/joaosaffran created https://github.com/llvm/llvm-project/pull/144577
This patch enhances error handling and validation in the DirectX backend's root signature parsing. The changes include:
1. **Improved Error Reporting**:
- Introduced `reportInvalidTypeError` utility to provide detailed error messages for type mismatches.
- Enhanced diagnostic messages for invalid metadata nodes and values.
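2. **Validation Updates**:
- Replaced the generic "Invalid value for ..." errors in descriptor table, descriptor range, and static sampler parsing with type-mismatch diagnostics.
- Invalid values for filter modes, address modes, and LOD parameters now report the operand index together with the type that was actually found.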
Example changes:
```cpp
if (Element == nullptr)
  return reportInvalidTypeError<MDNode>(Ctx, "DescriptorTableNode",
                                        DescriptorTableNode, I);

if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 1))
  Sampler.Filter = *Val;
else
  return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
                                             StaticSamplerNode, 1);
```
Testing:
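- Validation of invalid metadata nodes and values: the RootConstants tests for invalid `ShaderRegister`, `RegisterSpace`, and `Num32BitValues` were updated to check the new diagnostic text.
- Proper diagnostic messages for type mismatches: each message names the node, the expected type, the operand index, and the type actually found.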
- All existing DirectX backend tests continue to pass.
From 02f1f21b8ecc608341440c573483e69c161a06d4 Mon Sep 17 00:00:00 2001
From: joaosaffran <joao.saffran at microsoft.com>
Date: Fri, 6 Jun 2025 20:04:00 +0000
Subject: [PATCH 1/2] changing error message
---
llvm/lib/Target/DirectX/DXILRootSignature.cpp | 119 +++++++++++++++---
...re-RootConstants-Invalid-Num32BitValues.ll | 2 +-
...ure-RootConstants-Invalid-RegisterSpace.ll | 2 +-
...re-RootConstants-Invalid-ShaderRegister.ll | 2 +-
4 files changed, 104 insertions(+), 21 deletions(-)
diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.cpp b/llvm/lib/Target/DirectX/DXILRootSignature.cpp
index 3aef7d3eb1e69..3a27afc6c660f 100644
--- a/llvm/lib/Target/DirectX/DXILRootSignature.cpp
+++ b/llvm/lib/Target/DirectX/DXILRootSignature.cpp
@@ -12,6 +12,7 @@
//===----------------------------------------------------------------------===//
#include "DXILRootSignature.h"
#include "DirectX.h"
+#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/DXILMetadataAnalysis.h"
@@ -30,6 +31,7 @@
#include <cmath>
#include <cstdint>
#include <optional>
+#include <string>
#include <utility>
using namespace llvm;
@@ -48,6 +50,71 @@ static bool reportValueError(LLVMContext *Ctx, Twine ParamName,
return true;
}
+// Returns a human-readable name for the C++ type T, used in diagnostics that
+// state which kind of metadata operand was expected. Pointer and cv
+// qualifiers are stripped before matching, so MDNode, MDNode *, and
+// const MDNode * all map to the same name. Unmapped types yield "unknown".
+template <typename T> static std::string getTypeFormatted() {
+  // Normalize T so callers may pass either the class or a (const) pointer.
+  using Base = std::remove_cv_t<std::remove_pointer_t<std::remove_cv_t<T>>>;
+  if constexpr (std::is_same_v<Base, MDString>)
+    return "string";
+  else if constexpr (std::is_same_v<Base, MDNode>)
+    return "metadata";
+  else if constexpr (std::is_same_v<Base, ConstantAsMetadata>)
+    return "constant";
+  else if constexpr (std::is_same_v<Base, ConstantInt>)
+    return "constant int";
+  else if constexpr (std::is_same_v<Base, ConstantFP>)
+    return "constant fp";
+  else if constexpr (std::is_same_v<Base, float>)
+    return "float";
+  else if constexpr (std::is_same_v<Base, double>)
+    return "double";
+  else
+    return "unknown";
+}
+
+// Describes the operand at position Index of Node for use in diagnostics.
+// Returns:
+//   "null"      - Node is null, Index is out of range, or the operand is null;
+//   "string"    - the operand is an MDString;
+//   "iN"        - an integer constant of bit width N (e.g. "i32");
+//   "float" / "double" / "fp" - a floating-point constant;
+//   "constant"  - any other ConstantAsMetadata;
+//   "metadata"  - a nested MDNode;
+//   "unknown"   - anything else.
+static std::string getActualMDType(const MDNode *Node, unsigned Index) {
+  if (!Node || Index >= Node->getNumOperands())
+    return "null";
+
+  const Metadata *Op = Node->getOperand(Index);
+  if (!Op)
+    return "null";
+
+  if (isa<MDString>(Op))
+    return "string";
+
+  if (const auto *CAM = dyn_cast<ConstantAsMetadata>(Op)) {
+    Type *Ty = CAM->getValue()->getType();
+    if (Ty->isIntegerTy())
+      return ("i" + Twine(Ty->getIntegerBitWidth())).str();
+    // Name float/double explicitly; other FP widths collapse to "fp".
+    if (Ty->isFloatTy())
+      return "float";
+    if (Ty->isDoubleTy())
+      return "double";
+    if (Ty->isFloatingPointTy())
+      return "fp";
+    return "constant";
+  }
+
+  if (isa<MDNode>(Op))
+    return "metadata";
+
+  return "unknown";
+}
+
+// Emits a diagnostic stating that operand Index of Node (identified in the
+// message as ParamName) does not have the expected type ET, and names the
+// type actually found there. Returns reportError's result as a bool, so parse
+// routines (which return true on failure) can write
+// `return reportInvalidTypeError<...>(...);`.
+template <typename ET>
+static bool reportInvalidTypeError(LLVMContext *Ctx, const Twine &ParamName,
+                                   const MDNode *Node, unsigned Index) {
+  std::string ExpectedType = getTypeFormatted<ET>();
+  std::string ActualType = getActualMDType(Node, Index);
+
+  return reportError(Ctx, "Root Signature Node: " + ParamName +
+                              " expected metadata node of type " +
+                              ExpectedType + " at index " + Twine(Index) +
+                              " but got " + ActualType);
+}
+
static std::optional<uint32_t> extractMdIntValue(MDNode *Node,
unsigned int OpId) {
if (auto *CI =
@@ -80,7 +147,8 @@ static bool parseRootFlags(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
if (std::optional<uint32_t> Val = extractMdIntValue(RootFlagNode, 1))
RSD.Flags = *Val;
else
- return reportError(Ctx, "Invalid value for RootFlag");
+ return reportInvalidTypeError<ConstantInt>(Ctx, "RootFlagNode",
+ RootFlagNode, 1);
return false;
}
@@ -100,23 +168,27 @@ static bool parseRootConstants(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 1))
Header.ShaderVisibility = *Val;
else
- return reportError(Ctx, "Invalid value for ShaderVisibility");
+ return reportInvalidTypeError<ConstantInt>(Ctx, "RootConstantNode",
+ RootConstantNode, 1);
dxbc::RTS0::v1::RootConstants Constants;
if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 2))
Constants.ShaderRegister = *Val;
else
- return reportError(Ctx, "Invalid value for ShaderRegister");
+ return reportInvalidTypeError<ConstantInt>(Ctx, "RootConstantNode",
+ RootConstantNode, 2);
if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 3))
Constants.RegisterSpace = *Val;
else
- return reportError(Ctx, "Invalid value for RegisterSpace");
+ return reportInvalidTypeError<ConstantInt>(Ctx, "RootConstantNode",
+ RootConstantNode, 3);
if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 4))
Constants.Num32BitValues = *Val;
else
- return reportError(Ctx, "Invalid value for Num32BitValues");
+ return reportInvalidTypeError<ConstantInt>(Ctx, "RootConstantNode",
+ RootConstantNode, 4);
RSD.ParametersContainer.addParameter(Header, Constants);
@@ -154,18 +226,21 @@ static bool parseRootDescriptors(LLVMContext *Ctx,
if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 1))
Header.ShaderVisibility = *Val;
else
- return reportError(Ctx, "Invalid value for ShaderVisibility");
+ return reportInvalidTypeError<ConstantInt>(Ctx, "RootDescriptorNode",
+ RootDescriptorNode, 1);
dxbc::RTS0::v2::RootDescriptor Descriptor;
if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 2))
Descriptor.ShaderRegister = *Val;
else
- return reportError(Ctx, "Invalid value for ShaderRegister");
+ return reportInvalidTypeError<ConstantInt>(Ctx, "RootDescriptorNode",
+ RootDescriptorNode, 2);
if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 3))
Descriptor.RegisterSpace = *Val;
else
- return reportError(Ctx, "Invalid value for RegisterSpace");
+ return reportInvalidTypeError<ConstantInt>(Ctx, "RootDescriptorNode",
+ RootDescriptorNode, 3);
if (RSD.Version == 1) {
RSD.ParametersContainer.addParameter(Header, Descriptor);
@@ -176,7 +251,8 @@ static bool parseRootDescriptors(LLVMContext *Ctx,
if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 4))
Descriptor.Flags = *Val;
else
- return reportError(Ctx, "Invalid value for Root Descriptor Flags");
+ return reportInvalidTypeError<ConstantInt>(Ctx, "RootDescriptorNode",
+ RootDescriptorNode, 4);
RSD.ParametersContainer.addParameter(Header, Descriptor);
return false;
@@ -196,7 +272,8 @@ static bool parseDescriptorRange(LLVMContext *Ctx,
extractMdStringValue(RangeDescriptorNode, 0);
if (!ElementText.has_value())
- return reportError(Ctx, "Descriptor Range, first element is not a string.");
+ return reportInvalidTypeError<MDString>(Ctx, "RangeDescriptorNode",
+ RangeDescriptorNode, 0);
Range.RangeType =
StringSwitch<uint32_t>(*ElementText)
@@ -213,28 +290,32 @@ static bool parseDescriptorRange(LLVMContext *Ctx,
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 1))
Range.NumDescriptors = *Val;
else
- return reportError(Ctx, "Invalid value for Number of Descriptor in Range");
+ return reportInvalidTypeError<MDString>(Ctx, "RangeDescriptorNode",
+ RangeDescriptorNode, 1);
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 2))
Range.BaseShaderRegister = *Val;
else
- return reportError(Ctx, "Invalid value for BaseShaderRegister");
+ return reportInvalidTypeError<MDString>(Ctx, "RangeDescriptorNode",
+ RangeDescriptorNode, 2);
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 3))
Range.RegisterSpace = *Val;
else
- return reportError(Ctx, "Invalid value for RegisterSpace");
+ return reportInvalidTypeError<MDString>(Ctx, "RangeDescriptorNode",
+ RangeDescriptorNode, 3);
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 4))
Range.OffsetInDescriptorsFromTableStart = *Val;
else
- return reportError(Ctx,
- "Invalid value for OffsetInDescriptorsFromTableStart");
+ return reportInvalidTypeError<MDString>(Ctx, "RangeDescriptorNode",
+ RangeDescriptorNode, 4);
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 5))
Range.Flags = *Val;
else
- return reportError(Ctx, "Invalid value for Descriptor Range Flags");
+ return reportInvalidTypeError<MDString>(Ctx, "RangeDescriptorNode",
+ RangeDescriptorNode, 5);
Table.Ranges.push_back(Range);
return false;
@@ -251,7 +332,8 @@ static bool parseDescriptorTable(LLVMContext *Ctx,
if (std::optional<uint32_t> Val = extractMdIntValue(DescriptorTableNode, 1))
Header.ShaderVisibility = *Val;
else
- return reportError(Ctx, "Invalid value for ShaderVisibility");
+ return reportInvalidTypeError<MDString>(Ctx, "DescriptorTableNode",
+ DescriptorTableNode, 1);
mcdxbc::DescriptorTable Table;
Header.ParameterType =
@@ -260,7 +342,8 @@ static bool parseDescriptorTable(LLVMContext *Ctx,
for (unsigned int I = 2; I < NumOperands; I++) {
MDNode *Element = dyn_cast<MDNode>(DescriptorTableNode->getOperand(I));
if (Element == nullptr)
- return reportError(Ctx, "Missing Root Element Metadata Node.");
+ return reportInvalidTypeError<MDNode>(Ctx, "DescriptorTableNode",
+ DescriptorTableNode, I);
if (parseDescriptorRange(Ctx, RSD, Table, Element))
return true;
diff --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootConstants-Invalid-Num32BitValues.ll b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootConstants-Invalid-Num32BitValues.ll
index 552c128e5ab57..0d5bbdfc097c4 100644
--- a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootConstants-Invalid-Num32BitValues.ll
+++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootConstants-Invalid-Num32BitValues.ll
@@ -2,7 +2,7 @@
target triple = "dxil-unknown-shadermodel6.0-compute"
-; CHECK: error: Invalid value for Num32BitValues
+; CHECK: error: Root Signature Node: RootConstantNode expected metadata node of type constant int at index 4 but got string
; CHECK-NOT: Root Signature Definitions
define void @main() {
diff --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootConstants-Invalid-RegisterSpace.ll b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootConstants-Invalid-RegisterSpace.ll
index 1087b414942e2..1384da4baca98 100644
--- a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootConstants-Invalid-RegisterSpace.ll
+++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootConstants-Invalid-RegisterSpace.ll
@@ -2,7 +2,7 @@
target triple = "dxil-unknown-shadermodel6.0-compute"
-; CHECK: error: Invalid value for RegisterSpace
+; CHECK: error: Root Signature Node: RootConstantNode expected metadata node of type constant int at index 3 but got string
; CHECK-NOT: Root Signature Definitions
define void @main() #0 {
diff --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootConstants-Invalid-ShaderRegister.ll b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootConstants-Invalid-ShaderRegister.ll
index 53fd924e8f46e..e1fd6a4414609 100644
--- a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootConstants-Invalid-ShaderRegister.ll
+++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootConstants-Invalid-ShaderRegister.ll
@@ -2,7 +2,7 @@
target triple = "dxil-unknown-shadermodel6.0-compute"
-; CHECK: error: Invalid value for ShaderRegister
+; CHECK: error: Root Signature Node: RootConstantNode expected metadata node of type constant int at index 2 but got string
; CHECK-NOT: Root Signature Definitions
define void @main() #0 {
From e62419f82edd38bb027f3451dc350ecb01b0be2c Mon Sep 17 00:00:00 2001
From: joaosaffran <joao.saffran at microsoft.com>
Date: Mon, 16 Jun 2025 19:50:29 +0000
Subject: [PATCH 2/2] clean up
---
llvm/lib/Target/DirectX/DXILRootSignature.cpp | 65 +++++++++++--------
1 file changed, 38 insertions(+), 27 deletions(-)
diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.cpp b/llvm/lib/Target/DirectX/DXILRootSignature.cpp
index 3a27afc6c660f..57d5ee8ac467c 100644
--- a/llvm/lib/Target/DirectX/DXILRootSignature.cpp
+++ b/llvm/lib/Target/DirectX/DXILRootSignature.cpp
@@ -12,7 +12,6 @@
//===----------------------------------------------------------------------===//
#include "DXILRootSignature.h"
#include "DirectX.h"
-#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/DXILMetadataAnalysis.h"
@@ -31,7 +30,6 @@
#include <cmath>
#include <cstdint>
#include <optional>
-#include <string>
#include <utility>
using namespace llvm;
@@ -290,32 +288,32 @@ static bool parseDescriptorRange(LLVMContext *Ctx,
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 1))
Range.NumDescriptors = *Val;
else
- return reportInvalidTypeError<MDString>(Ctx, "RangeDescriptorNode",
- RangeDescriptorNode, 1);
+ return reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode",
+ RangeDescriptorNode, 1);
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 2))
Range.BaseShaderRegister = *Val;
else
- return reportInvalidTypeError<MDString>(Ctx, "RangeDescriptorNode",
- RangeDescriptorNode, 2);
+ return reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode",
+ RangeDescriptorNode, 2);
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 3))
Range.RegisterSpace = *Val;
else
- return reportInvalidTypeError<MDString>(Ctx, "RangeDescriptorNode",
- RangeDescriptorNode, 3);
+ return reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode",
+ RangeDescriptorNode, 3);
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 4))
Range.OffsetInDescriptorsFromTableStart = *Val;
else
- return reportInvalidTypeError<MDString>(Ctx, "RangeDescriptorNode",
- RangeDescriptorNode, 4);
+ return reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode",
+ RangeDescriptorNode, 4);
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 5))
Range.Flags = *Val;
else
- return reportInvalidTypeError<MDString>(Ctx, "RangeDescriptorNode",
- RangeDescriptorNode, 5);
+ return reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode",
+ RangeDescriptorNode, 5);
Table.Ranges.push_back(Range);
return false;
@@ -332,8 +330,8 @@ static bool parseDescriptorTable(LLVMContext *Ctx,
if (std::optional<uint32_t> Val = extractMdIntValue(DescriptorTableNode, 1))
Header.ShaderVisibility = *Val;
else
- return reportInvalidTypeError<MDString>(Ctx, "DescriptorTableNode",
- DescriptorTableNode, 1);
+ return reportInvalidTypeError<ConstantInt>(Ctx, "DescriptorTableNode",
+ DescriptorTableNode, 1);
mcdxbc::DescriptorTable Table;
Header.ParameterType =
@@ -362,67 +360,80 @@ static bool parseStaticSampler(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 1))
Sampler.Filter = *Val;
else
- return reportError(Ctx, "Invalid value for Filter");
+ return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+ StaticSamplerNode, 1);
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 2))
Sampler.AddressU = *Val;
else
- return reportError(Ctx, "Invalid value for AddressU");
+ return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+ StaticSamplerNode, 2);
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 3))
Sampler.AddressV = *Val;
else
- return reportError(Ctx, "Invalid value for AddressV");
+ return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+ StaticSamplerNode, 3);
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 4))
Sampler.AddressW = *Val;
else
- return reportError(Ctx, "Invalid value for AddressW");
+ return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+ StaticSamplerNode, 4);
if (std::optional<APFloat> Val = extractMdFloatValue(StaticSamplerNode, 5))
Sampler.MipLODBias = Val->convertToFloat();
else
- return reportError(Ctx, "Invalid value for MipLODBias");
+ return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+ StaticSamplerNode, 5);
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 6))
Sampler.MaxAnisotropy = *Val;
else
- return reportError(Ctx, "Invalid value for MaxAnisotropy");
+ return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+ StaticSamplerNode, 6);
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 7))
Sampler.ComparisonFunc = *Val;
else
- return reportError(Ctx, "Invalid value for ComparisonFunc ");
+ return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+ StaticSamplerNode, 7);
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 8))
Sampler.BorderColor = *Val;
else
- return reportError(Ctx, "Invalid value for ComparisonFunc ");
+ return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+ StaticSamplerNode, 8);
if (std::optional<APFloat> Val = extractMdFloatValue(StaticSamplerNode, 9))
Sampler.MinLOD = Val->convertToFloat();
else
- return reportError(Ctx, "Invalid value for MinLOD");
+ return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+ StaticSamplerNode, 9);
if (std::optional<APFloat> Val = extractMdFloatValue(StaticSamplerNode, 10))
Sampler.MaxLOD = Val->convertToFloat();
else
- return reportError(Ctx, "Invalid value for MaxLOD");
+ return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+ StaticSamplerNode, 10);
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 11))
Sampler.ShaderRegister = *Val;
else
- return reportError(Ctx, "Invalid value for ShaderRegister");
+ return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+ StaticSamplerNode, 11);
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 12))
Sampler.RegisterSpace = *Val;
else
- return reportError(Ctx, "Invalid value for RegisterSpace");
+ return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+ StaticSamplerNode, 12);
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 13))
Sampler.ShaderVisibility = *Val;
else
- return reportError(Ctx, "Invalid value for ShaderVisibility");
+ return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+ StaticSamplerNode, 13);
RSD.StaticSamplers.push_back(Sampler);
return false;
More information about the llvm-branch-commits
mailing list