[clang] [llvm] [DirectX] Fix Flags validation to prevent casting into enum (PR #161587)

via llvm-commits llvm-commits at lists.llvm.org
Wed Oct 1 14:00:33 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-directx

Author: None (joaosaffran)

<details>
<summary>Changes</summary>

This PR changes the validation logic for Root Descriptor and Descriptor Range flags so that the raw `uint32_t` values are checked to be within the valid range before they are cast to the corresponding enum types.

---
Full diff: https://github.com/llvm/llvm-project/pull/161587.diff


6 Files Affected:

- (modified) clang/lib/Sema/SemaHLSL.cpp (+2-2) 
- (modified) llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h (+1-1) 
- (modified) llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp (+1-2) 
- (modified) llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp (+11-1) 
- (added) llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-Invalid-Flag-LargeNumber.ll (+20) 
- (added) llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootDescriptor-Invalid-Flags-LargeNumber.ll (+18) 


``````````diff
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 129b03c07c0bd..a2e8afb9bb8ff 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1322,8 +1322,8 @@ bool SemaHLSL::handleRootSignatureElements(
         ReportError(Loc, 1, 0xfffffffe);
       }
 
-      if (!llvm::hlsl::rootsig::verifyDescriptorRangeFlag(Version, Clause->Type,
-                                                          Clause->Flags))
+      if (!llvm::hlsl::rootsig::verifyDescriptorRangeFlag(
+              Version, Clause->Type, llvm::to_underlying(Clause->Flags)))
         ReportFlagError(Loc);
     }
   }
diff --git a/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h b/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h
index 4dd18111b0c9d..10723a181f025 100644
--- a/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h
+++ b/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h
@@ -32,7 +32,7 @@ LLVM_ABI bool verifyRootDescriptorFlag(uint32_t Version, uint32_t FlagsVal);
 LLVM_ABI bool verifyRangeType(uint32_t Type);
 LLVM_ABI bool verifyDescriptorRangeFlag(uint32_t Version,
                                         dxil::ResourceClass Type,
-                                        dxbc::DescriptorRangeFlags FlagsVal);
+                                        uint32_t FlagsVal);
 LLVM_ABI bool verifyStaticSamplerFlags(uint32_t Version, uint32_t FlagsNumber);
 LLVM_ABI bool verifyNumDescriptors(uint32_t NumDescriptors);
 LLVM_ABI bool verifyMipLODBias(float MipLODBias);
diff --git a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
index 5785505ce2b0c..2a22364563f90 100644
--- a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
+++ b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
@@ -665,8 +665,7 @@ Error MetadataParser::validateRootSignature(
                              "NumDescriptors", Range.NumDescriptors));
 
         if (!hlsl::rootsig::verifyDescriptorRangeFlag(
-                RSD.Version, Range.RangeType,
-                dxbc::DescriptorRangeFlags(Range.Flags)))
+                RSD.Version, Range.RangeType, Range.Flags))
           DeferredErrs =
               joinErrors(std::move(DeferredErrs),
                          make_error<RootSignatureValidationError<uint32_t>>(
diff --git a/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp b/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp
index 2c78d622f7f28..e887906955dd2 100644
--- a/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp
+++ b/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp
@@ -36,6 +36,11 @@ bool verifyRegisterSpace(uint32_t RegisterSpace) {
 
 bool verifyRootDescriptorFlag(uint32_t Version, uint32_t FlagsVal) {
   using FlagT = dxbc::RootDescriptorFlags;
+  uint32_t LargestValue =
+      llvm::to_underlying(FlagT::LLVM_BITMASK_LARGEST_ENUMERATOR);
+  if (FlagsVal >= NextPowerOf2(LargestValue))
+    return false;
+
   FlagT Flags = FlagT(FlagsVal);
   if (Version == 1)
     return Flags == FlagT::DataVolatile;
@@ -54,9 +59,14 @@ bool verifyRootDescriptorFlag(uint32_t Version, uint32_t FlagsVal) {
 }
 
 bool verifyDescriptorRangeFlag(uint32_t Version, dxil::ResourceClass Type,
-                               dxbc::DescriptorRangeFlags Flags) {
+                               uint32_t FlagsVal) {
   using FlagT = dxbc::DescriptorRangeFlags;
+  uint32_t LargestValue =
+      llvm::to_underlying(FlagT::LLVM_BITMASK_LARGEST_ENUMERATOR);
+  if (FlagsVal >= NextPowerOf2(LargestValue))
+    return false;
 
+  FlagT Flags = FlagT(FlagsVal);
   const bool IsSampler = (Type == dxil::ResourceClass::Sampler);
 
   if (Version == 1) {
diff --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-Invalid-Flag-LargeNumber.ll b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-Invalid-Flag-LargeNumber.ll
new file mode 100644
index 0000000000000..c27c87ff057d5
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-Invalid-Flag-LargeNumber.ll
@@ -0,0 +1,20 @@
+; RUN: not opt -passes='print<dxil-root-signature>' %s -S -o - 2>&1 | FileCheck %s
+
+target triple = "dxil-unknown-shadermodel6.0-compute"
+
+; CHECK: error: Invalid value for DescriptorFlag: 66666
+; CHECK-NOT: Root Signature Definitions
+
+define void @main() #0 {
+entry:
+  ret void
+}
+attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
+
+
+!dx.rootsignatures = !{!2} ; list of function/root signature pairs
+!2 = !{ ptr @main, !3, i32 2 } ; function, root signature
+!3 = !{ !5 } ; list of root signature elements
+!5 = !{ !"DescriptorTable", i32 0, !6, !7 }
+!6 = !{ !"SRV", i32 1, i32 1, i32 0, i32 -1, i32 66666 }
+!7 = !{ !"UAV", i32 5, i32 1, i32 10, i32 5, i32 2 }
diff --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootDescriptor-Invalid-Flags-LargeNumber.ll b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootDescriptor-Invalid-Flags-LargeNumber.ll
new file mode 100644
index 0000000000000..898e197c7e0cc
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootDescriptor-Invalid-Flags-LargeNumber.ll
@@ -0,0 +1,18 @@
+; RUN: not opt -passes='print<dxil-root-signature>' %s -S -o - 2>&1 | FileCheck %s
+
+target triple = "dxil-unknown-shadermodel6.0-compute"
+
+
+; CHECK: error: Invalid value for RootDescriptorFlag: 666
+; CHECK-NOT: Root Signature Definitions
+define void @main() #0 {
+entry:
+  ret void
+}
+attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
+
+
+!dx.rootsignatures = !{!2} ; list of function/root signature pairs
+!2 = !{ ptr @main, !3, i32 2 } ; function, root signature
+!3 = !{ !5 } ; list of root signature elements
+!5 = !{ !"RootCBV", i32 0, i32 1, i32 2, i32 666  }

``````````

</details>


https://github.com/llvm/llvm-project/pull/161587


More information about the llvm-commits mailing list