[llvm] [DirectX] Error handling improve in root signature metadata Parser (PR #149232)

Chris B via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 28 11:47:56 PDT 2025


================
@@ -523,145 +507,219 @@ bool MetadataParser::parseRootSignatureElement(LLVMContext *Ctx,
   switch (ElementKind) {
 
   case RootSignatureElementKind::RootFlags:
-    return parseRootFlags(Ctx, RSD, Element);
+    return parseRootFlags(RSD, Element);
   case RootSignatureElementKind::RootConstants:
-    return parseRootConstants(Ctx, RSD, Element);
+    return parseRootConstants(RSD, Element);
   case RootSignatureElementKind::CBV:
   case RootSignatureElementKind::SRV:
   case RootSignatureElementKind::UAV:
-    return parseRootDescriptors(Ctx, RSD, Element, ElementKind);
+    return parseRootDescriptors(RSD, Element, ElementKind);
   case RootSignatureElementKind::DescriptorTable:
-    return parseDescriptorTable(Ctx, RSD, Element);
+    return parseDescriptorTable(RSD, Element);
   case RootSignatureElementKind::StaticSamplers:
-    return parseStaticSampler(Ctx, RSD, Element);
+    return parseStaticSampler(RSD, Element);
   case RootSignatureElementKind::Error:
-    return reportError(Ctx, "Invalid Root Signature Element: " + *ElementText);
+    return make_error<GenericRSMetadataError>(
+        "Invalid Root Signature Element: " + *ElementText, Element);
   }
 
   llvm_unreachable("Unhandled RootSignatureElementKind enum.");
 }
 
-bool MetadataParser::validateRootSignature(
-    LLVMContext *Ctx, const llvm::mcdxbc::RootSignatureDesc &RSD) {
-  if (!llvm::hlsl::rootsig::verifyVersion(RSD.Version)) {
-    return reportValueError(Ctx, "Version", RSD.Version);
+Error MetadataParser::validateRootSignature(
+    const mcdxbc::RootSignatureDesc &RSD) {
+  Error DeferredErrs = Error::success();
+  if (!hlsl::rootsig::verifyVersion(RSD.Version)) {
+    DeferredErrs =
+        joinErrors(std::move(DeferredErrs),
+                   make_error<RootSignatureValidationError<uint32_t>>(
+                       "Version", RSD.Version));
   }
 
-  if (!llvm::hlsl::rootsig::verifyRootFlag(RSD.Flags)) {
-    return reportValueError(Ctx, "RootFlags", RSD.Flags);
+  if (!hlsl::rootsig::verifyRootFlag(RSD.Flags)) {
+    DeferredErrs =
+        joinErrors(std::move(DeferredErrs),
+                   make_error<RootSignatureValidationError<uint32_t>>(
+                       "RootFlags", RSD.Flags));
   }
 
   for (const mcdxbc::RootParameterInfo &Info : RSD.ParametersContainer) {
     if (!dxbc::isValidShaderVisibility(Info.Header.ShaderVisibility))
-      return reportValueError(Ctx, "ShaderVisibility",
-                              Info.Header.ShaderVisibility);
+      DeferredErrs =
+          joinErrors(std::move(DeferredErrs),
+                     make_error<RootSignatureValidationError<uint32_t>>(
+                         "ShaderVisibility", Info.Header.ShaderVisibility));
 
     assert(dxbc::isValidParameterType(Info.Header.ParameterType) &&
            "Invalid value for ParameterType");
 
     switch (Info.Header.ParameterType) {
 
-    case llvm::to_underlying(dxbc::RootParameterType::CBV):
-    case llvm::to_underlying(dxbc::RootParameterType::UAV):
-    case llvm::to_underlying(dxbc::RootParameterType::SRV): {
+    case to_underlying(dxbc::RootParameterType::CBV):
+    case to_underlying(dxbc::RootParameterType::UAV):
+    case to_underlying(dxbc::RootParameterType::SRV): {
       const dxbc::RTS0::v2::RootDescriptor &Descriptor =
           RSD.ParametersContainer.getRootDescriptor(Info.Location);
-      if (!llvm::hlsl::rootsig::verifyRegisterValue(Descriptor.ShaderRegister))
-        return reportValueError(Ctx, "ShaderRegister",
-                                Descriptor.ShaderRegister);
-
-      if (!llvm::hlsl::rootsig::verifyRegisterSpace(Descriptor.RegisterSpace))
-        return reportValueError(Ctx, "RegisterSpace", Descriptor.RegisterSpace);
+      if (!hlsl::rootsig::verifyRegisterValue(Descriptor.ShaderRegister))
+        DeferredErrs =
+            joinErrors(std::move(DeferredErrs),
+                       make_error<RootSignatureValidationError<uint32_t>>(
+                           "ShaderRegister", Descriptor.ShaderRegister));
+
+      if (!hlsl::rootsig::verifyRegisterSpace(Descriptor.RegisterSpace))
+        DeferredErrs =
+            joinErrors(std::move(DeferredErrs),
+                       make_error<RootSignatureValidationError<uint32_t>>(
+                           "RegisterSpace", Descriptor.RegisterSpace));
 
       if (RSD.Version > 1) {
-        if (!llvm::hlsl::rootsig::verifyRootDescriptorFlag(RSD.Version,
-                                                           Descriptor.Flags))
-          return reportValueError(Ctx, "RootDescriptorFlag", Descriptor.Flags);
+        if (!hlsl::rootsig::verifyRootDescriptorFlag(RSD.Version,
+                                                     Descriptor.Flags))
+          DeferredErrs =
+              joinErrors(std::move(DeferredErrs),
+                         make_error<RootSignatureValidationError<uint32_t>>(
+                             "RootDescriptorFlag", Descriptor.Flags));
       }
       break;
     }
-    case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
+    case to_underlying(dxbc::RootParameterType::DescriptorTable): {
       const mcdxbc::DescriptorTable &Table =
           RSD.ParametersContainer.getDescriptorTable(Info.Location);
       for (const dxbc::RTS0::v2::DescriptorRange &Range : Table) {
-        if (!llvm::hlsl::rootsig::verifyRangeType(Range.RangeType))
-          return reportValueError(Ctx, "RangeType", Range.RangeType);
-
-        if (!llvm::hlsl::rootsig::verifyRegisterSpace(Range.RegisterSpace))
-          return reportValueError(Ctx, "RegisterSpace", Range.RegisterSpace);
-
-        if (!llvm::hlsl::rootsig::verifyNumDescriptors(Range.NumDescriptors))
-          return reportValueError(Ctx, "NumDescriptors", Range.NumDescriptors);
-
-        if (!llvm::hlsl::rootsig::verifyDescriptorRangeFlag(
+        if (!hlsl::rootsig::verifyRangeType(Range.RangeType))
+          DeferredErrs =
+              joinErrors(std::move(DeferredErrs),
+                         make_error<RootSignatureValidationError<uint32_t>>(
+                             "RangeType", Range.RangeType));
+
+        if (!hlsl::rootsig::verifyRegisterSpace(Range.RegisterSpace))
+          DeferredErrs =
+              joinErrors(std::move(DeferredErrs),
+                         make_error<RootSignatureValidationError<uint32_t>>(
+                             "RegisterSpace", Range.RegisterSpace));
+
+        if (!hlsl::rootsig::verifyNumDescriptors(Range.NumDescriptors))
+          DeferredErrs =
+              joinErrors(std::move(DeferredErrs),
+                         make_error<RootSignatureValidationError<uint32_t>>(
+                             "NumDescriptors", Range.NumDescriptors));
+
+        if (!hlsl::rootsig::verifyDescriptorRangeFlag(
                 RSD.Version, Range.RangeType, Range.Flags))
-          return reportValueError(Ctx, "DescriptorFlag", Range.Flags);
+          DeferredErrs =
+              joinErrors(std::move(DeferredErrs),
+                         make_error<RootSignatureValidationError<uint32_t>>(
+                             "DescriptorFlag", Range.Flags));
       }
       break;
     }
     }
   }
 
   for (const dxbc::RTS0::v1::StaticSampler &Sampler : RSD.StaticSamplers) {
-    if (!llvm::hlsl::rootsig::verifySamplerFilter(Sampler.Filter))
-      return reportValueError(Ctx, "Filter", Sampler.Filter);
-
-    if (!llvm::hlsl::rootsig::verifyAddress(Sampler.AddressU))
-      return reportValueError(Ctx, "AddressU", Sampler.AddressU);
-
-    if (!llvm::hlsl::rootsig::verifyAddress(Sampler.AddressV))
-      return reportValueError(Ctx, "AddressV", Sampler.AddressV);
-
-    if (!llvm::hlsl::rootsig::verifyAddress(Sampler.AddressW))
-      return reportValueError(Ctx, "AddressW", Sampler.AddressW);
-
-    if (!llvm::hlsl::rootsig::verifyMipLODBias(Sampler.MipLODBias))
-      return reportValueError(Ctx, "MipLODBias", Sampler.MipLODBias);
-
-    if (!llvm::hlsl::rootsig::verifyMaxAnisotropy(Sampler.MaxAnisotropy))
-      return reportValueError(Ctx, "MaxAnisotropy", Sampler.MaxAnisotropy);
-
-    if (!llvm::hlsl::rootsig::verifyComparisonFunc(Sampler.ComparisonFunc))
-      return reportValueError(Ctx, "ComparisonFunc", Sampler.ComparisonFunc);
-
-    if (!llvm::hlsl::rootsig::verifyBorderColor(Sampler.BorderColor))
-      return reportValueError(Ctx, "BorderColor", Sampler.BorderColor);
-
-    if (!llvm::hlsl::rootsig::verifyLOD(Sampler.MinLOD))
-      return reportValueError(Ctx, "MinLOD", Sampler.MinLOD);
-
-    if (!llvm::hlsl::rootsig::verifyLOD(Sampler.MaxLOD))
-      return reportValueError(Ctx, "MaxLOD", Sampler.MaxLOD);
-
-    if (!llvm::hlsl::rootsig::verifyRegisterValue(Sampler.ShaderRegister))
-      return reportValueError(Ctx, "ShaderRegister", Sampler.ShaderRegister);
-
-    if (!llvm::hlsl::rootsig::verifyRegisterSpace(Sampler.RegisterSpace))
-      return reportValueError(Ctx, "RegisterSpace", Sampler.RegisterSpace);
+    if (!hlsl::rootsig::verifySamplerFilter(Sampler.Filter))
+      DeferredErrs =
+          joinErrors(std::move(DeferredErrs),
+                     make_error<RootSignatureValidationError<uint32_t>>(
+                         "Filter", Sampler.Filter));
+
+    if (!hlsl::rootsig::verifyAddress(Sampler.AddressU))
+      DeferredErrs =
+          joinErrors(std::move(DeferredErrs),
+                     make_error<RootSignatureValidationError<uint32_t>>(
+                         "AddressU", Sampler.AddressU));
+
+    if (!hlsl::rootsig::verifyAddress(Sampler.AddressV))
+      DeferredErrs =
+          joinErrors(std::move(DeferredErrs),
+                     make_error<RootSignatureValidationError<uint32_t>>(
+                         "AddressV", Sampler.AddressV));
+
+    if (!hlsl::rootsig::verifyAddress(Sampler.AddressW))
+      DeferredErrs =
+          joinErrors(std::move(DeferredErrs),
+                     make_error<RootSignatureValidationError<uint32_t>>(
+                         "AddressW", Sampler.AddressW));
+
+    if (!hlsl::rootsig::verifyMipLODBias(Sampler.MipLODBias))
+      DeferredErrs = joinErrors(std::move(DeferredErrs),
+                                make_error<RootSignatureValidationError<float>>(
+                                    "MipLODBias", Sampler.MipLODBias));
+
+    if (!hlsl::rootsig::verifyMaxAnisotropy(Sampler.MaxAnisotropy))
+      DeferredErrs =
+          joinErrors(std::move(DeferredErrs),
+                     make_error<RootSignatureValidationError<uint32_t>>(
+                         "MaxAnisotropy", Sampler.MaxAnisotropy));
+
+    if (!hlsl::rootsig::verifyComparisonFunc(Sampler.ComparisonFunc))
+      DeferredErrs =
+          joinErrors(std::move(DeferredErrs),
+                     make_error<RootSignatureValidationError<uint32_t>>(
+                         "ComparisonFunc", Sampler.ComparisonFunc));
+
+    if (!hlsl::rootsig::verifyBorderColor(Sampler.BorderColor))
+      DeferredErrs =
+          joinErrors(std::move(DeferredErrs),
+                     make_error<RootSignatureValidationError<uint32_t>>(
+                         "BorderColor", Sampler.BorderColor));
+
+    if (!hlsl::rootsig::verifyLOD(Sampler.MinLOD))
+      DeferredErrs = joinErrors(std::move(DeferredErrs),
+                                make_error<RootSignatureValidationError<float>>(
+                                    "MinLOD", Sampler.MinLOD));
+
+    if (!hlsl::rootsig::verifyLOD(Sampler.MaxLOD))
+      DeferredErrs = joinErrors(std::move(DeferredErrs),
+                                make_error<RootSignatureValidationError<float>>(
+                                    "MaxLOD", Sampler.MaxLOD));
+
+    if (!hlsl::rootsig::verifyRegisterValue(Sampler.ShaderRegister))
+      DeferredErrs =
+          joinErrors(std::move(DeferredErrs),
+                     make_error<RootSignatureValidationError<uint32_t>>(
+                         "ShaderRegister", Sampler.ShaderRegister));
+
+    if (!hlsl::rootsig::verifyRegisterSpace(Sampler.RegisterSpace))
+      DeferredErrs =
+          joinErrors(std::move(DeferredErrs),
+                     make_error<RootSignatureValidationError<uint32_t>>(
+                         "RegisterSpace", Sampler.RegisterSpace));
 
     if (!dxbc::isValidShaderVisibility(Sampler.ShaderVisibility))
-      return reportValueError(Ctx, "ShaderVisibility",
-                              Sampler.ShaderVisibility);
+      DeferredErrs =
+          joinErrors(std::move(DeferredErrs),
+                     make_error<RootSignatureValidationError<uint32_t>>(
+                         "ShaderVisibility", Sampler.ShaderVisibility));
   }
 
-  return false;
+  return DeferredErrs;
 }
 
-bool MetadataParser::ParseRootSignature(LLVMContext *Ctx,
-                                        mcdxbc::RootSignatureDesc &RSD) {
-  bool HasError = false;
-
-  // Loop through the Root Elements of the root signature.
+Expected<mcdxbc::RootSignatureDesc>
+MetadataParser::ParseRootSignature(uint32_t Version) {
+  Error DeferredErrs = Error::success();
+  mcdxbc::RootSignatureDesc RSD;
+  RSD.Version = Version;
   for (const auto &Operand : Root->operands()) {
     MDNode *Element = dyn_cast<MDNode>(Operand);
     if (Element == nullptr)
-      return reportError(Ctx, "Missing Root Element Metadata Node.");
+      return joinErrors(std::move(DeferredErrs),
+                        make_error<GenericRSMetadataError>(
+                            "Missing Root Element Metadata Node.", nullptr));
 
-    HasError = HasError || parseRootSignatureElement(Ctx, RSD, Element) ||
-               validateRootSignature(Ctx, RSD);
+    if (auto Err = parseRootSignatureElement(RSD, Element)) {
+      DeferredErrs = joinErrors(std::move(DeferredErrs), std::move(Err));
+    }
----------------
llvm-beanz wrote:

```suggestion
    if (auto Err = parseRootSignatureElement(RSD, Element))
      DeferredErrs = joinErrors(std::move(DeferredErrs), std::move(Err));
```

https://github.com/llvm/llvm-project/pull/149232


More information about the llvm-commits mailing list