[llvm] 58ce3e2 - [DirectX] Fix Flags validation to prevent casting into enum (#161587)

via llvm-commits llvm-commits at lists.llvm.org
Mon Oct 6 14:21:33 PDT 2025


Author: joaosaffran
Date: 2025-10-06T17:21:28-04:00
New Revision: 58ce3e20e55c67298da94eabe4e25f8baeb45fa7

URL: https://github.com/llvm/llvm-project/commit/58ce3e20e55c67298da94eabe4e25f8baeb45fa7
DIFF: https://github.com/llvm/llvm-project/commit/58ce3e20e55c67298da94eabe4e25f8baeb45fa7.diff

LOG: [DirectX] Fix Flags validation to prevent casting into enum (#161587)

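This PR changes the validation logic for Root Descriptor and Descriptor
Range flags to properly check that the `uint32_t` values are within range
before casting them into the enums. The equivalent range check for Static
Sampler flags, which previously lived inside `verifyStaticSamplerFlags`,
is moved into a matching `isValidStaticSamplerFlags` helper.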

Added: 
    llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-Invalid-Flag-LargeNumber.ll
    llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootDescriptor-Invalid-Flags-LargeNumber.ll

Modified: 
    clang/lib/Sema/SemaHLSL.cpp
    llvm/include/llvm/BinaryFormat/DXContainer.h
    llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h
    llvm/lib/BinaryFormat/DXContainer.cpp
    llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
    llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp

Removed: 
    


################################################################################
diff  --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 2b375b93e3a48..a662b72c2a362 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1289,8 +1289,8 @@ bool SemaHLSL::handleRootSignatureElements(
       VerifyRegister(Loc, Descriptor->Reg.Number);
       VerifySpace(Loc, Descriptor->Space);
 
-      if (!llvm::hlsl::rootsig::verifyRootDescriptorFlag(
-              Version, llvm::to_underlying(Descriptor->Flags)))
+      if (!llvm::hlsl::rootsig::verifyRootDescriptorFlag(Version,
+                                                         Descriptor->Flags))
         ReportFlagError(Loc);
     } else if (const auto *Constants =
                    std::get_if<llvm::hlsl::rootsig::RootConstants>(&Elem)) {

diff  --git a/llvm/include/llvm/BinaryFormat/DXContainer.h b/llvm/include/llvm/BinaryFormat/DXContainer.h
index 0b5646229e8b5..b9a08ce1ca14e 100644
--- a/llvm/include/llvm/BinaryFormat/DXContainer.h
+++ b/llvm/include/llvm/BinaryFormat/DXContainer.h
@@ -248,6 +248,12 @@ enum class StaticBorderColor : uint32_t {
 
 bool isValidBorderColor(uint32_t V);
 
+bool isValidRootDesciptorFlags(uint32_t V);
+
+bool isValidDescriptorRangeFlags(uint32_t V);
+
+bool isValidStaticSamplerFlags(uint32_t V);
+
 LLVM_ABI ArrayRef<EnumEntry<StaticBorderColor>> getStaticBorderColors();
 
 LLVM_ABI PartType parsePartType(StringRef S);

diff  --git a/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h b/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h
index 4dd18111b0c9d..7131980e9ff3a 100644
--- a/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h
+++ b/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h
@@ -28,12 +28,14 @@ LLVM_ABI bool verifyRootFlag(uint32_t Flags);
 LLVM_ABI bool verifyVersion(uint32_t Version);
 LLVM_ABI bool verifyRegisterValue(uint32_t RegisterValue);
 LLVM_ABI bool verifyRegisterSpace(uint32_t RegisterSpace);
-LLVM_ABI bool verifyRootDescriptorFlag(uint32_t Version, uint32_t FlagsVal);
+LLVM_ABI bool verifyRootDescriptorFlag(uint32_t Version,
+                                       dxbc::RootDescriptorFlags Flags);
 LLVM_ABI bool verifyRangeType(uint32_t Type);
 LLVM_ABI bool verifyDescriptorRangeFlag(uint32_t Version,
                                         dxil::ResourceClass Type,
-                                        dxbc::DescriptorRangeFlags FlagsVal);
-LLVM_ABI bool verifyStaticSamplerFlags(uint32_t Version, uint32_t FlagsNumber);
+                                        dxbc::DescriptorRangeFlags Flags);
+LLVM_ABI bool verifyStaticSamplerFlags(uint32_t Version,
+                                       dxbc::StaticSamplerFlags Flags);
 LLVM_ABI bool verifyNumDescriptors(uint32_t NumDescriptors);
 LLVM_ABI bool verifyMipLODBias(float MipLODBias);
 LLVM_ABI bool verifyMaxAnisotropy(uint32_t MaxAnisotropy);

diff  --git a/llvm/lib/BinaryFormat/DXContainer.cpp b/llvm/lib/BinaryFormat/DXContainer.cpp
index b334f86568acb..22f518067b318 100644
--- a/llvm/lib/BinaryFormat/DXContainer.cpp
+++ b/llvm/lib/BinaryFormat/DXContainer.cpp
@@ -82,6 +82,27 @@ bool llvm::dxbc::isValidBorderColor(uint32_t V) {
   return false;
 }
 
+bool llvm::dxbc::isValidRootDesciptorFlags(uint32_t V) {
+  using FlagT = dxbc::RootDescriptorFlags;
+  uint32_t LargestValue =
+      llvm::to_underlying(FlagT::LLVM_BITMASK_LARGEST_ENUMERATOR);
+  return V < NextPowerOf2(LargestValue);
+}
+
+bool llvm::dxbc::isValidDescriptorRangeFlags(uint32_t V) {
+  using FlagT = dxbc::DescriptorRangeFlags;
+  uint32_t LargestValue =
+      llvm::to_underlying(FlagT::LLVM_BITMASK_LARGEST_ENUMERATOR);
+  return V < NextPowerOf2(LargestValue);
+}
+
+bool llvm::dxbc::isValidStaticSamplerFlags(uint32_t V) {
+  using FlagT = dxbc::StaticSamplerFlags;
+  uint32_t LargestValue =
+      llvm::to_underlying(FlagT::LLVM_BITMASK_LARGEST_ENUMERATOR);
+  return V < NextPowerOf2(LargestValue);
+}
+
 dxbc::PartType dxbc::parsePartType(StringRef S) {
 #define CONTAINER_PART(PartName) .Case(#PartName, PartType::PartName)
   return StringSwitch<dxbc::PartType>(S)

diff  --git a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
index 7a0cf408968de..707f0c368e9d8 100644
--- a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
+++ b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
@@ -651,8 +651,11 @@ Error MetadataParser::validateRootSignature(
                            "RegisterSpace", Descriptor.RegisterSpace));
 
       if (RSD.Version > 1) {
-        if (!hlsl::rootsig::verifyRootDescriptorFlag(RSD.Version,
-                                                     Descriptor.Flags))
+        bool IsValidFlag =
+            dxbc::isValidRootDesciptorFlags(Descriptor.Flags) &&
+            hlsl::rootsig::verifyRootDescriptorFlag(
+                RSD.Version, dxbc::RootDescriptorFlags(Descriptor.Flags));
+        if (!IsValidFlag)
           DeferredErrs =
               joinErrors(std::move(DeferredErrs),
                          make_error<RootSignatureValidationError<uint32_t>>(
@@ -676,9 +679,11 @@ Error MetadataParser::validateRootSignature(
                          make_error<RootSignatureValidationError<uint32_t>>(
                              "NumDescriptors", Range.NumDescriptors));
 
-        if (!hlsl::rootsig::verifyDescriptorRangeFlag(
-                RSD.Version, Range.RangeType,
-                dxbc::DescriptorRangeFlags(Range.Flags)))
+        bool IsValidFlag = dxbc::isValidDescriptorRangeFlags(Range.Flags) &&
+                           hlsl::rootsig::verifyDescriptorRangeFlag(
+                               RSD.Version, Range.RangeType,
+                               dxbc::DescriptorRangeFlags(Range.Flags));
+        if (!IsValidFlag)
           DeferredErrs =
               joinErrors(std::move(DeferredErrs),
                          make_error<RootSignatureValidationError<uint32_t>>(
@@ -731,8 +736,11 @@ Error MetadataParser::validateRootSignature(
           joinErrors(std::move(DeferredErrs),
                      make_error<RootSignatureValidationError<uint32_t>>(
                          "RegisterSpace", Sampler.RegisterSpace));
-
-    if (!hlsl::rootsig::verifyStaticSamplerFlags(RSD.Version, Sampler.Flags))
+    bool IsValidFlag =
+        dxbc::isValidStaticSamplerFlags(Sampler.Flags) &&
+        hlsl::rootsig::verifyStaticSamplerFlags(
+            RSD.Version, dxbc::StaticSamplerFlags(Sampler.Flags));
+    if (!IsValidFlag)
       DeferredErrs =
           joinErrors(std::move(DeferredErrs),
                      make_error<RootSignatureValidationError<uint32_t>>(

diff  --git a/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp b/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp
index 8a2b03d9ede8b..30408dfda940d 100644
--- a/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp
+++ b/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp
@@ -34,7 +34,8 @@ bool verifyRegisterSpace(uint32_t RegisterSpace) {
   return !(RegisterSpace >= 0xFFFFFFF0);
 }
 
-bool verifyRootDescriptorFlag(uint32_t Version, uint32_t FlagsVal) {
+bool verifyRootDescriptorFlag(uint32_t Version,
+                              dxbc::RootDescriptorFlags FlagsVal) {
   using FlagT = dxbc::RootDescriptorFlags;
   FlagT Flags = FlagT(FlagsVal);
   if (Version == 1)
@@ -56,7 +57,6 @@ bool verifyRootDescriptorFlag(uint32_t Version, uint32_t FlagsVal) {
 bool verifyDescriptorRangeFlag(uint32_t Version, dxil::ResourceClass Type,
                                dxbc::DescriptorRangeFlags Flags) {
   using FlagT = dxbc::DescriptorRangeFlags;
-
   const bool IsSampler = (Type == dxil::ResourceClass::Sampler);
 
   if (Version == 1) {
@@ -113,13 +113,8 @@ bool verifyDescriptorRangeFlag(uint32_t Version, dxil::ResourceClass Type,
   return (Flags & ~Mask) == FlagT::None;
 }
 
-bool verifyStaticSamplerFlags(uint32_t Version, uint32_t FlagsNumber) {
-  uint32_t LargestValue = llvm::to_underlying(
-      dxbc::StaticSamplerFlags::LLVM_BITMASK_LARGEST_ENUMERATOR);
-  if (FlagsNumber >= NextPowerOf2(LargestValue))
-    return false;
-
-  dxbc::StaticSamplerFlags Flags = dxbc::StaticSamplerFlags(FlagsNumber);
+bool verifyStaticSamplerFlags(uint32_t Version,
+                              dxbc::StaticSamplerFlags Flags) {
   if (Version <= 2)
     return Flags == dxbc::StaticSamplerFlags::None;
 

diff  --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-Invalid-Flag-LargeNumber.ll b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-Invalid-Flag-LargeNumber.ll
new file mode 100644
index 0000000000000..c27c87ff057d5
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-Invalid-Flag-LargeNumber.ll
@@ -0,0 +1,20 @@
+; RUN: not opt -passes='print<dxil-root-signature>' %s -S -o - 2>&1 | FileCheck %s
+
+target triple = "dxil-unknown-shadermodel6.0-compute"
+
+; CHECK: error: Invalid value for DescriptorFlag: 66666
+; CHECK-NOT: Root Signature Definitions
+
+define void @main() #0 {
+entry:
+  ret void
+}
+attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
+
+
+!dx.rootsignatures = !{!2} ; list of function/root signature pairs
+!2 = !{ ptr @main, !3, i32 2 } ; function, root signature
+!3 = !{ !5 } ; list of root signature elements
+!5 = !{ !"DescriptorTable", i32 0, !6, !7 }
+!6 = !{ !"SRV", i32 1, i32 1, i32 0, i32 -1, i32 66666 }
+!7 = !{ !"UAV", i32 5, i32 1, i32 10, i32 5, i32 2 }

diff  --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootDescriptor-Invalid-Flags-LargeNumber.ll b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootDescriptor-Invalid-Flags-LargeNumber.ll
new file mode 100644
index 0000000000000..898e197c7e0cc
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootDescriptor-Invalid-Flags-LargeNumber.ll
@@ -0,0 +1,18 @@
+; RUN: not opt -passes='print<dxil-root-signature>' %s -S -o - 2>&1 | FileCheck %s
+
+target triple = "dxil-unknown-shadermodel6.0-compute"
+
+
+; CHECK: error: Invalid value for RootDescriptorFlag: 666
+; CHECK-NOT: Root Signature Definitions
+define void @main() #0 {
+entry:
+  ret void
+}
+attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
+
+
+!dx.rootsignatures = !{!2} ; list of function/root signature pairs
+!2 = !{ ptr @main, !3, i32 2 } ; function, root signature
+!3 = !{ !5 } ; list of root signature elements
+!5 = !{ !"RootCBV", i32 0, i32 1, i32 2, i32 666  }


        


More information about the llvm-commits mailing list