[clang] [llvm] [DirectX] Fix Flags validation to prevent casting into enum (PR #161587)
via cfe-commits
cfe-commits at lists.llvm.org
Mon Oct 6 13:29:51 PDT 2025
https://github.com/joaosaffran updated https://github.com/llvm/llvm-project/pull/161587
>From 82832bc1604b6677f25d96af028788c0f8648b15 Mon Sep 17 00:00:00 2001
From: Joao Saffran <joaosaffranllvm at gmail.com>
Date: Wed, 1 Oct 2025 12:23:38 -0700
Subject: [PATCH 1/2] fix validation logic
---
clang/lib/Sema/SemaHLSL.cpp | 4 ++--
.../Frontend/HLSL/RootSignatureValidations.h | 2 +-
.../Frontend/HLSL/RootSignatureMetadata.cpp | 3 +--
.../HLSL/RootSignatureValidations.cpp | 12 ++++++++++-
...escriptorTable-Invalid-Flag-LargeNumber.ll | 20 +++++++++++++++++++
...ootDescriptor-Invalid-Flags-LargeNumber.ll | 18 +++++++++++++++++
6 files changed, 53 insertions(+), 6 deletions(-)
create mode 100644 llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-Invalid-Flag-LargeNumber.ll
create mode 100644 llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootDescriptor-Invalid-Flags-LargeNumber.ll
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 129b03c07c0bd..a2e8afb9bb8ff 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1322,8 +1322,8 @@ bool SemaHLSL::handleRootSignatureElements(
ReportError(Loc, 1, 0xfffffffe);
}
- if (!llvm::hlsl::rootsig::verifyDescriptorRangeFlag(Version, Clause->Type,
- Clause->Flags))
+ if (!llvm::hlsl::rootsig::verifyDescriptorRangeFlag(
+ Version, Clause->Type, llvm::to_underlying(Clause->Flags)))
ReportFlagError(Loc);
}
}
diff --git a/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h b/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h
index 4dd18111b0c9d..10723a181f025 100644
--- a/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h
+++ b/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h
@@ -32,7 +32,7 @@ LLVM_ABI bool verifyRootDescriptorFlag(uint32_t Version, uint32_t FlagsVal);
LLVM_ABI bool verifyRangeType(uint32_t Type);
LLVM_ABI bool verifyDescriptorRangeFlag(uint32_t Version,
dxil::ResourceClass Type,
- dxbc::DescriptorRangeFlags FlagsVal);
+ uint32_t FlagsVal);
LLVM_ABI bool verifyStaticSamplerFlags(uint32_t Version, uint32_t FlagsNumber);
LLVM_ABI bool verifyNumDescriptors(uint32_t NumDescriptors);
LLVM_ABI bool verifyMipLODBias(float MipLODBias);
diff --git a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
index 5785505ce2b0c..2a22364563f90 100644
--- a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
+++ b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
@@ -665,8 +665,7 @@ Error MetadataParser::validateRootSignature(
"NumDescriptors", Range.NumDescriptors));
if (!hlsl::rootsig::verifyDescriptorRangeFlag(
- RSD.Version, Range.RangeType,
- dxbc::DescriptorRangeFlags(Range.Flags)))
+ RSD.Version, Range.RangeType, Range.Flags))
DeferredErrs =
joinErrors(std::move(DeferredErrs),
make_error<RootSignatureValidationError<uint32_t>>(
diff --git a/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp b/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp
index 2c78d622f7f28..e887906955dd2 100644
--- a/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp
+++ b/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp
@@ -36,6 +36,11 @@ bool verifyRegisterSpace(uint32_t RegisterSpace) {
bool verifyRootDescriptorFlag(uint32_t Version, uint32_t FlagsVal) {
using FlagT = dxbc::RootDescriptorFlags;
+ uint32_t LargestValue =
+ llvm::to_underlying(FlagT::LLVM_BITMASK_LARGEST_ENUMERATOR);
+ if (FlagsVal >= NextPowerOf2(LargestValue))
+ return false;
+
FlagT Flags = FlagT(FlagsVal);
if (Version == 1)
return Flags == FlagT::DataVolatile;
@@ -54,9 +59,14 @@ bool verifyRootDescriptorFlag(uint32_t Version, uint32_t FlagsVal) {
}
bool verifyDescriptorRangeFlag(uint32_t Version, dxil::ResourceClass Type,
- dxbc::DescriptorRangeFlags Flags) {
+ uint32_t FlagsVal) {
using FlagT = dxbc::DescriptorRangeFlags;
+ uint32_t LargestValue =
+ llvm::to_underlying(FlagT::LLVM_BITMASK_LARGEST_ENUMERATOR);
+ if (FlagsVal >= NextPowerOf2(LargestValue))
+ return false;
+ FlagT Flags = FlagT(FlagsVal);
const bool IsSampler = (Type == dxil::ResourceClass::Sampler);
if (Version == 1) {
diff --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-Invalid-Flag-LargeNumber.ll b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-Invalid-Flag-LargeNumber.ll
new file mode 100644
index 0000000000000..c27c87ff057d5
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-Invalid-Flag-LargeNumber.ll
@@ -0,0 +1,20 @@
+; RUN: not opt -passes='print<dxil-root-signature>' %s -S -o - 2>&1 | FileCheck %s
+
+target triple = "dxil-unknown-shadermodel6.0-compute"
+
+; CHECK: error: Invalid value for DescriptorFlag: 66666
+; CHECK-NOT: Root Signature Definitions
+
+define void @main() #0 {
+entry:
+ ret void
+}
+attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
+
+
+!dx.rootsignatures = !{!2} ; list of function/root signature pairs
+!2 = !{ ptr @main, !3, i32 2 } ; function, root signature
+!3 = !{ !5 } ; list of root signature elements
+!5 = !{ !"DescriptorTable", i32 0, !6, !7 }
+!6 = !{ !"SRV", i32 1, i32 1, i32 0, i32 -1, i32 66666 }
+!7 = !{ !"UAV", i32 5, i32 1, i32 10, i32 5, i32 2 }
diff --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootDescriptor-Invalid-Flags-LargeNumber.ll b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootDescriptor-Invalid-Flags-LargeNumber.ll
new file mode 100644
index 0000000000000..898e197c7e0cc
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootDescriptor-Invalid-Flags-LargeNumber.ll
@@ -0,0 +1,18 @@
+; RUN: not opt -passes='print<dxil-root-signature>' %s -S -o - 2>&1 | FileCheck %s
+
+target triple = "dxil-unknown-shadermodel6.0-compute"
+
+
+; CHECK: error: Invalid value for RootDescriptorFlag: 666
+; CHECK-NOT: Root Signature Definitions
+define void @main() #0 {
+entry:
+ ret void
+}
+attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
+
+
+!dx.rootsignatures = !{!2} ; list of function/root signature pairs
+!2 = !{ ptr @main, !3, i32 2 } ; function, root signature
+!3 = !{ !5 } ; list of root signature elements
+!5 = !{ !"RootCBV", i32 0, i32 1, i32 2, i32 666 }
>From 23bea276768244273ac50fbeb366916a81569ab4 Mon Sep 17 00:00:00 2001
From: Joao Saffran <joaosaffranllvm at gmail.com>
Date: Mon, 6 Oct 2025 13:29:30 -0700
Subject: [PATCH 2/2] addressing comment from bogner
---
clang/lib/Sema/SemaHLSL.cpp | 8 +++---
llvm/include/llvm/BinaryFormat/DXContainer.h | 6 +++++
.../Frontend/HLSL/RootSignatureValidations.h | 8 +++---
llvm/lib/BinaryFormat/DXContainer.cpp | 21 ++++++++++++++++
.../Frontend/HLSL/RootSignatureMetadata.cpp | 21 +++++++++++-----
.../HLSL/RootSignatureValidations.cpp | 25 ++++---------------
6 files changed, 56 insertions(+), 33 deletions(-)
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index b392ec648598f..a662b72c2a362 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1289,8 +1289,8 @@ bool SemaHLSL::handleRootSignatureElements(
VerifyRegister(Loc, Descriptor->Reg.Number);
VerifySpace(Loc, Descriptor->Space);
- if (!llvm::hlsl::rootsig::verifyRootDescriptorFlag(
- Version, llvm::to_underlying(Descriptor->Flags)))
+ if (!llvm::hlsl::rootsig::verifyRootDescriptorFlag(Version,
+ Descriptor->Flags))
ReportFlagError(Loc);
} else if (const auto *Constants =
std::get_if<llvm::hlsl::rootsig::RootConstants>(&Elem)) {
@@ -1322,8 +1322,8 @@ bool SemaHLSL::handleRootSignatureElements(
ReportError(Loc, 1, 0xfffffffe);
}
- if (!llvm::hlsl::rootsig::verifyDescriptorRangeFlag(
- Version, Clause->Type, llvm::to_underlying(Clause->Flags)))
+ if (!llvm::hlsl::rootsig::verifyDescriptorRangeFlag(Version, Clause->Type,
+ Clause->Flags))
ReportFlagError(Loc);
}
}
diff --git a/llvm/include/llvm/BinaryFormat/DXContainer.h b/llvm/include/llvm/BinaryFormat/DXContainer.h
index 0b5646229e8b5..b9a08ce1ca14e 100644
--- a/llvm/include/llvm/BinaryFormat/DXContainer.h
+++ b/llvm/include/llvm/BinaryFormat/DXContainer.h
@@ -248,6 +248,12 @@ enum class StaticBorderColor : uint32_t {
bool isValidBorderColor(uint32_t V);
+/// Returns true if \p V has no bit set above the highest bit of the largest
+/// RootDescriptorFlags enumerator, so it can be converted to that enum.
+/// NOTE(review): "Desciptor" is misspelled; renaming requires updating the
+/// definition in DXContainer.cpp and the caller in RootSignatureMetadata.cpp
+/// together.
+LLVM_ABI bool isValidRootDesciptorFlags(uint32_t V);
+
+/// Returns true if \p V has no bit set above the highest bit of the largest
+/// DescriptorRangeFlags enumerator, so it can be converted to that enum.
+LLVM_ABI bool isValidDescriptorRangeFlags(uint32_t V);
+
+/// Returns true if \p V has no bit set above the highest bit of the largest
+/// StaticSamplerFlags enumerator, so it can be converted to that enum.
+LLVM_ABI bool isValidStaticSamplerFlags(uint32_t V);
+
LLVM_ABI ArrayRef<EnumEntry<StaticBorderColor>> getStaticBorderColors();
LLVM_ABI PartType parsePartType(StringRef S);
diff --git a/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h b/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h
index 10723a181f025..7131980e9ff3a 100644
--- a/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h
+++ b/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h
@@ -28,12 +28,14 @@ LLVM_ABI bool verifyRootFlag(uint32_t Flags);
LLVM_ABI bool verifyVersion(uint32_t Version);
LLVM_ABI bool verifyRegisterValue(uint32_t RegisterValue);
LLVM_ABI bool verifyRegisterSpace(uint32_t RegisterSpace);
-LLVM_ABI bool verifyRootDescriptorFlag(uint32_t Version, uint32_t FlagsVal);
+LLVM_ABI bool verifyRootDescriptorFlag(uint32_t Version,
+ dxbc::RootDescriptorFlags Flags);
LLVM_ABI bool verifyRangeType(uint32_t Type);
LLVM_ABI bool verifyDescriptorRangeFlag(uint32_t Version,
dxil::ResourceClass Type,
- uint32_t FlagsVal);
-LLVM_ABI bool verifyStaticSamplerFlags(uint32_t Version, uint32_t FlagsNumber);
+ dxbc::DescriptorRangeFlags Flags);
+LLVM_ABI bool verifyStaticSamplerFlags(uint32_t Version,
+ dxbc::StaticSamplerFlags Flags);
LLVM_ABI bool verifyNumDescriptors(uint32_t NumDescriptors);
LLVM_ABI bool verifyMipLODBias(float MipLODBias);
LLVM_ABI bool verifyMaxAnisotropy(uint32_t MaxAnisotropy);
diff --git a/llvm/lib/BinaryFormat/DXContainer.cpp b/llvm/lib/BinaryFormat/DXContainer.cpp
index b334f86568acb..22f518067b318 100644
--- a/llvm/lib/BinaryFormat/DXContainer.cpp
+++ b/llvm/lib/BinaryFormat/DXContainer.cpp
@@ -82,6 +82,27 @@ bool llvm::dxbc::isValidBorderColor(uint32_t V) {
return false;
}
+// Range check performed before a raw uint32_t is converted to
+// dxbc::RootDescriptorFlags: V is accepted only if it is strictly below
+// NextPowerOf2 of the largest enumerator, i.e. V sets no bit above the highest
+// bit that enumerator occupies.
+// NOTE(review): bits below that bound that are not themselves enumerators are
+// NOT rejected here; the version-specific mask check in
+// hlsl::rootsig::verifyRootDescriptorFlag is still required (callers combine
+// both with &&).
+bool llvm::dxbc::isValidRootDesciptorFlags(uint32_t V) {
+ using FlagT = dxbc::RootDescriptorFlags;
+ uint32_t LargestValue =
+ llvm::to_underlying(FlagT::LLVM_BITMASK_LARGEST_ENUMERATOR);
+ return V < NextPowerOf2(LargestValue);
+}
+
+// Same upper-bound check as above, for dxbc::DescriptorRangeFlags. Per-version
+// and per-resource-class validity is left to
+// hlsl::rootsig::verifyDescriptorRangeFlag.
+bool llvm::dxbc::isValidDescriptorRangeFlags(uint32_t V) {
+ using FlagT = dxbc::DescriptorRangeFlags;
+ uint32_t LargestValue =
+ llvm::to_underlying(FlagT::LLVM_BITMASK_LARGEST_ENUMERATOR);
+ return V < NextPowerOf2(LargestValue);
+}
+
+// Same upper-bound check as above, for dxbc::StaticSamplerFlags. Per-version
+// validity is left to hlsl::rootsig::verifyStaticSamplerFlags.
+bool llvm::dxbc::isValidStaticSamplerFlags(uint32_t V) {
+ using FlagT = dxbc::StaticSamplerFlags;
+ uint32_t LargestValue =
+ llvm::to_underlying(FlagT::LLVM_BITMASK_LARGEST_ENUMERATOR);
+ return V < NextPowerOf2(LargestValue);
+}
+
dxbc::PartType dxbc::parsePartType(StringRef S) {
#define CONTAINER_PART(PartName) .Case(#PartName, PartType::PartName)
return StringSwitch<dxbc::PartType>(S)
diff --git a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
index b36d1234a6774..707f0c368e9d8 100644
--- a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
+++ b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
@@ -651,8 +651,11 @@ Error MetadataParser::validateRootSignature(
"RegisterSpace", Descriptor.RegisterSpace));
if (RSD.Version > 1) {
- if (!hlsl::rootsig::verifyRootDescriptorFlag(RSD.Version,
- Descriptor.Flags))
+ bool IsValidFlag =
+ dxbc::isValidRootDesciptorFlags(Descriptor.Flags) &&
+ hlsl::rootsig::verifyRootDescriptorFlag(
+ RSD.Version, dxbc::RootDescriptorFlags(Descriptor.Flags));
+ if (!IsValidFlag)
DeferredErrs =
joinErrors(std::move(DeferredErrs),
make_error<RootSignatureValidationError<uint32_t>>(
@@ -676,8 +679,11 @@ Error MetadataParser::validateRootSignature(
make_error<RootSignatureValidationError<uint32_t>>(
"NumDescriptors", Range.NumDescriptors));
- if (!hlsl::rootsig::verifyDescriptorRangeFlag(
- RSD.Version, Range.RangeType, Range.Flags))
+ bool IsValidFlag = dxbc::isValidDescriptorRangeFlags(Range.Flags) &&
+ hlsl::rootsig::verifyDescriptorRangeFlag(
+ RSD.Version, Range.RangeType,
+ dxbc::DescriptorRangeFlags(Range.Flags));
+ if (!IsValidFlag)
DeferredErrs =
joinErrors(std::move(DeferredErrs),
make_error<RootSignatureValidationError<uint32_t>>(
@@ -730,8 +736,11 @@ Error MetadataParser::validateRootSignature(
joinErrors(std::move(DeferredErrs),
make_error<RootSignatureValidationError<uint32_t>>(
"RegisterSpace", Sampler.RegisterSpace));
-
- if (!hlsl::rootsig::verifyStaticSamplerFlags(RSD.Version, Sampler.Flags))
+ bool IsValidFlag =
+ dxbc::isValidStaticSamplerFlags(Sampler.Flags) &&
+ hlsl::rootsig::verifyStaticSamplerFlags(
+ RSD.Version, dxbc::StaticSamplerFlags(Sampler.Flags));
+ if (!IsValidFlag)
DeferredErrs =
joinErrors(std::move(DeferredErrs),
make_error<RootSignatureValidationError<uint32_t>>(
diff --git a/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp b/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp
index 47a73060924b0..30408dfda940d 100644
--- a/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp
+++ b/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp
@@ -34,13 +34,9 @@ bool verifyRegisterSpace(uint32_t RegisterSpace) {
return !(RegisterSpace >= 0xFFFFFFF0);
}
-bool verifyRootDescriptorFlag(uint32_t Version, uint32_t FlagsVal) {
+bool verifyRootDescriptorFlag(uint32_t Version,
+ dxbc::RootDescriptorFlags FlagsVal) {
using FlagT = dxbc::RootDescriptorFlags;
- uint32_t LargestValue =
- llvm::to_underlying(FlagT::LLVM_BITMASK_LARGEST_ENUMERATOR);
- if (FlagsVal >= NextPowerOf2(LargestValue))
- return false;
-
FlagT Flags = FlagT(FlagsVal);
if (Version == 1)
return Flags == FlagT::DataVolatile;
@@ -59,14 +55,8 @@ bool verifyRootDescriptorFlag(uint32_t Version, uint32_t FlagsVal) {
}
bool verifyDescriptorRangeFlag(uint32_t Version, dxil::ResourceClass Type,
- uint32_t FlagsVal) {
+ dxbc::DescriptorRangeFlags Flags) {
using FlagT = dxbc::DescriptorRangeFlags;
- uint32_t LargestValue =
- llvm::to_underlying(FlagT::LLVM_BITMASK_LARGEST_ENUMERATOR);
- if (FlagsVal >= NextPowerOf2(LargestValue))
- return false;
-
- FlagT Flags = FlagT(FlagsVal);
const bool IsSampler = (Type == dxil::ResourceClass::Sampler);
if (Version == 1) {
@@ -123,13 +113,8 @@ bool verifyDescriptorRangeFlag(uint32_t Version, dxil::ResourceClass Type,
return (Flags & ~Mask) == FlagT::None;
}
-bool verifyStaticSamplerFlags(uint32_t Version, uint32_t FlagsNumber) {
- uint32_t LargestValue = llvm::to_underlying(
- dxbc::StaticSamplerFlags::LLVM_BITMASK_LARGEST_ENUMERATOR);
- if (FlagsNumber >= NextPowerOf2(LargestValue))
- return false;
-
- dxbc::StaticSamplerFlags Flags = dxbc::StaticSamplerFlags(FlagsNumber);
+bool verifyStaticSamplerFlags(uint32_t Version,
+ dxbc::StaticSamplerFlags Flags) {
if (Version <= 2)
return Flags == dxbc::StaticSamplerFlags::None;
More information about the cfe-commits
mailing list