[llvm-branch-commits] [clang] [llvm] [HLSL][RootSignature] Add parsing for RootFlags (PR #138055)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed Apr 30 16:58:05 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-clang
@llvm/pr-subscribers-hlsl
Author: Finn Plummer (inbelic)
<details>
<summary>Changes</summary>
- defines the `RootFlags` in-memory enum
- defines `parseRootFlags` to parse the various flag enums into a single `uint32_t`
- adds corresponding unit tests
- improves the diagnostic message emitted when a non-zero integer literal is specified as a flag value
Resolves https://github.com/llvm/llvm-project/issues/126575
---
Full diff: https://github.com/llvm/llvm-project/pull/138055.diff
7 Files Affected:
- (modified) clang/include/clang/Basic/DiagnosticParseKinds.td (+1)
- (modified) clang/include/clang/Lex/HLSLRootSignatureTokenKinds.def (+19)
- (modified) clang/include/clang/Parse/ParseHLSLRootSignature.h (+1)
- (modified) clang/lib/Parse/ParseHLSLRootSignature.cpp (+63-10)
- (modified) clang/unittests/Lex/LexHLSLRootSignatureTest.cpp (+14-1)
- (modified) clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp (+51-1)
- (modified) llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h (+19-2)
``````````diff
diff --git a/clang/include/clang/Basic/DiagnosticParseKinds.td b/clang/include/clang/Basic/DiagnosticParseKinds.td
index 72e765bcb800d..75ed28f95cd32 100644
--- a/clang/include/clang/Basic/DiagnosticParseKinds.td
+++ b/clang/include/clang/Basic/DiagnosticParseKinds.td
@@ -1842,5 +1842,6 @@ def err_hlsl_unexpected_end_of_params
def err_hlsl_rootsig_repeat_param : Error<"specified the same parameter '%0' multiple times">;
def err_hlsl_rootsig_missing_param : Error<"did not specify mandatory parameter '%0'">;
def err_hlsl_number_literal_overflow : Error<"integer literal is too large to be represented as a 32-bit %select{signed |}0 integer type">;
+def err_hlsl_rootsig_non_zero_flag : Error<"non-zero integer literal specified for flag value">;
} // end of Parser diagnostics
diff --git a/clang/include/clang/Lex/HLSLRootSignatureTokenKinds.def b/clang/include/clang/Lex/HLSLRootSignatureTokenKinds.def
index ecb8cfc7afa16..eac6ebda84965 100644
--- a/clang/include/clang/Lex/HLSLRootSignatureTokenKinds.def
+++ b/clang/include/clang/Lex/HLSLRootSignatureTokenKinds.def
@@ -27,6 +27,9 @@
#endif
// Defines the various types of enum
+#ifndef ROOT_FLAG_ENUM
+#define ROOT_FLAG_ENUM(NAME, LIT) ENUM(NAME, LIT)
+#endif
#ifndef UNBOUNDED_ENUM
#define UNBOUNDED_ENUM(NAME, LIT) ENUM(NAME, LIT)
#endif
@@ -73,6 +76,7 @@ PUNCTUATOR(minus, '-')
// RootElement Keywords:
KEYWORD(RootSignature) // used only for diagnostic messaging
+KEYWORD(RootFlags)
KEYWORD(DescriptorTable)
KEYWORD(RootConstants)
@@ -100,6 +104,20 @@ UNBOUNDED_ENUM(unbounded, "unbounded")
// Descriptor Range Offset Enum:
DESCRIPTOR_RANGE_OFFSET_ENUM(DescriptorRangeOffsetAppend, "DESCRIPTOR_RANGE_OFFSET_APPEND")
+// Root Flag Enums:
+ROOT_FLAG_ENUM(AllowInputAssemblerInputLayout, "ALLOW_INPUT_ASSEMBLER_INPUT_LAYOUT")
+ROOT_FLAG_ENUM(DenyVertexShaderRootAccess, "DENY_VERTEX_SHADER_ROOT_ACCESS")
+ROOT_FLAG_ENUM(DenyHullShaderRootAccess, "DENY_HULL_SHADER_ROOT_ACCESS")
+ROOT_FLAG_ENUM(DenyDomainShaderRootAccess, "DENY_DOMAIN_SHADER_ROOT_ACCESS")
+ROOT_FLAG_ENUM(DenyGeometryShaderRootAccess, "DENY_GEOMETRY_SHADER_ROOT_ACCESS")
+ROOT_FLAG_ENUM(DenyPixelShaderRootAccess, "DENY_PIXEL_SHADER_ROOT_ACCESS")
+ROOT_FLAG_ENUM(DenyAmplificationShaderRootAccess, "DENY_AMPLIFICATION_SHADER_ROOT_ACCESS")
+ROOT_FLAG_ENUM(DenyMeshShaderRootAccess, "DENY_MESH_SHADER_ROOT_ACCESS")
+ROOT_FLAG_ENUM(AllowStreamOutput, "ALLOW_STREAM_OUTPUT")
+ROOT_FLAG_ENUM(LocalRootSignature, "LOCAL_ROOT_SIGNATURE")
+ROOT_FLAG_ENUM(CBVSRVUAVHeapDirectlyIndexed, "CBV_SRV_UAV_HEAP_DIRECTLY_INDEXED")
+ROOT_FLAG_ENUM(SamplerHeapDirectlyIndexed , "SAMPLER_HEAP_DIRECTLY_INDEXED")
+
// Root Descriptor Flag Enums:
ROOT_DESCRIPTOR_FLAG_ENUM(DataVolatile, "DATA_VOLATILE")
ROOT_DESCRIPTOR_FLAG_ENUM(DataStaticWhileSetAtExecute, "DATA_STATIC_WHILE_SET_AT_EXECUTE")
@@ -127,6 +145,7 @@ SHADER_VISIBILITY_ENUM(Mesh, "SHADER_VISIBILITY_MESH")
#undef DESCRIPTOR_RANGE_FLAG_ENUM_OFF
#undef DESCRIPTOR_RANGE_FLAG_ENUM_ON
#undef ROOT_DESCRIPTOR_FLAG_ENUM
+#undef ROOT_FLAG_ENUM
#undef DESCRIPTOR_RANGE_OFFSET_ENUM
#undef UNBOUNDED_ENUM
#undef ENUM
diff --git a/clang/include/clang/Parse/ParseHLSLRootSignature.h b/clang/include/clang/Parse/ParseHLSLRootSignature.h
index 2ac2083983741..915266f8a36ae 100644
--- a/clang/include/clang/Parse/ParseHLSLRootSignature.h
+++ b/clang/include/clang/Parse/ParseHLSLRootSignature.h
@@ -71,6 +71,7 @@ class RootSignatureParser {
// expected, or, there is a lexing error
/// Root Element parse methods:
+ std::optional<llvm::hlsl::rootsig::RootFlags> parseRootFlags();
std::optional<llvm::hlsl::rootsig::RootConstants> parseRootConstants();
std::optional<llvm::hlsl::rootsig::DescriptorTable> parseDescriptorTable();
std::optional<llvm::hlsl::rootsig::DescriptorTableClause>
diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp
index a5006b77a6e44..4780af0f94162 100644
--- a/clang/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -27,6 +27,13 @@ RootSignatureParser::RootSignatureParser(SmallVector<RootElement> &Elements,
bool RootSignatureParser::parse() {
// Iterate as many RootElements as possible
do {
+ if (tryConsumeExpectedToken(TokenKind::kw_RootFlags)) {
+ auto Flags = parseRootFlags();
+ if (!Flags.has_value())
+ return true;
+ Elements.push_back(*Flags);
+ }
+
if (tryConsumeExpectedToken(TokenKind::kw_RootConstants)) {
auto Constants = parseRootConstants();
if (!Constants.has_value())
@@ -47,6 +54,61 @@ bool RootSignatureParser::parse() {
/*param of=*/TokenKind::kw_RootSignature);
}
+template <typename FlagType>
+static FlagType maybeOrFlag(std::optional<FlagType> Flags, FlagType Flag) {
+ if (!Flags.has_value())
+ return Flag;
+
+ return static_cast<FlagType>(llvm::to_underlying(Flags.value()) |
+ llvm::to_underlying(Flag));
+}
+
+std::optional<RootFlags> RootSignatureParser::parseRootFlags() {
+ assert(CurToken.TokKind == TokenKind::kw_RootFlags &&
+ "Expects to only be invoked starting at given keyword");
+
+ if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
+ CurToken.TokKind))
+ return std::nullopt;
+
+ std::optional<RootFlags> Flags = RootFlags::None;
+
+ // Handle the edge-case of '0' to specify no flags set
+ if (tryConsumeExpectedToken(TokenKind::int_literal)) {
+ if (!verifyZeroFlag()) {
+ getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_non_zero_flag);
+ return std::nullopt;
+ }
+ } else {
+ // Otherwise, parse as many flags as possible
+ TokenKind Expected[] = {
+#define ROOT_FLAG_ENUM(NAME, LIT) TokenKind::en_##NAME,
+#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
+ };
+
+ do {
+ if (tryConsumeExpectedToken(Expected)) {
+ switch (CurToken.TokKind) {
+#define ROOT_FLAG_ENUM(NAME, LIT) \
+ case TokenKind::en_##NAME: \
+ Flags = maybeOrFlag<RootFlags>(Flags, RootFlags::NAME); \
+ break;
+#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
+ default:
+ llvm_unreachable("Switch for consumed enum token was not provided");
+ }
+ }
+ } while (tryConsumeExpectedToken(TokenKind::pu_or));
+ }
+
+ if (consumeExpectedToken(TokenKind::pu_r_paren,
+ diag::err_hlsl_unexpected_end_of_params,
+ /*param of=*/TokenKind::kw_RootFlags))
+ return std::nullopt;
+
+ return Flags;
+}
+
std::optional<RootConstants> RootSignatureParser::parseRootConstants() {
assert(CurToken.TokKind == TokenKind::kw_RootConstants &&
"Expects to only be invoked starting at given keyword");
@@ -467,15 +529,6 @@ RootSignatureParser::parseShaderVisibility() {
return std::nullopt;
}
-template <typename FlagType>
-static FlagType maybeOrFlag(std::optional<FlagType> Flags, FlagType Flag) {
- if (!Flags.has_value())
- return Flag;
-
- return static_cast<FlagType>(llvm::to_underlying(Flags.value()) |
- llvm::to_underlying(Flag));
-}
-
std::optional<llvm::hlsl::rootsig::DescriptorRangeFlags>
RootSignatureParser::parseDescriptorRangeFlags() {
assert(CurToken.TokKind == TokenKind::pu_equal &&
@@ -484,7 +537,7 @@ RootSignatureParser::parseDescriptorRangeFlags() {
// Handle the edge-case of '0' to specify no flags set
if (tryConsumeExpectedToken(TokenKind::int_literal)) {
if (!verifyZeroFlag()) {
- getDiags().Report(CurToken.TokLoc, diag::err_expected) << "'0'";
+ getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_non_zero_flag);
return std::nullopt;
}
return DescriptorRangeFlags::None;
diff --git a/clang/unittests/Lex/LexHLSLRootSignatureTest.cpp b/clang/unittests/Lex/LexHLSLRootSignatureTest.cpp
index 89e9a3183ad03..21a1f1f08ae05 100644
--- a/clang/unittests/Lex/LexHLSLRootSignatureTest.cpp
+++ b/clang/unittests/Lex/LexHLSLRootSignatureTest.cpp
@@ -87,7 +87,7 @@ TEST_F(LexHLSLRootSignatureTest, ValidLexAllTokensTest) {
RootSignature
- DescriptorTable RootConstants
+ RootFlags DescriptorTable RootConstants
num32BitConstants
@@ -98,6 +98,19 @@ TEST_F(LexHLSLRootSignatureTest, ValidLexAllTokensTest) {
unbounded
DESCRIPTOR_RANGE_OFFSET_APPEND
+ allow_input_assembler_input_layout
+ deny_vertex_shader_root_access
+ deny_hull_shader_root_access
+ deny_domain_shader_root_access
+ deny_geometry_shader_root_access
+ deny_pixel_shader_root_access
+ deny_amplification_shader_root_access
+ deny_mesh_shader_root_access
+ allow_stream_output
+ local_root_signature
+ cbv_srv_uav_heap_directly_indexed
+ sampler_heap_directly_indexed
+
DATA_VOLATILE
DATA_STATIC_WHILE_SET_AT_EXECUTE
DATA_STATIC
diff --git a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
index 150eb3e6e54ef..18e1e517dae8f 100644
--- a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
+++ b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
@@ -294,6 +294,56 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootConsantsTest) {
ASSERT_TRUE(Consumer->isSatisfied());
}
+TEST_F(ParseHLSLRootSignatureTest, ValidParseRootFlagsTest) {
+ const llvm::StringLiteral Source = R"cc(
+ RootFlags(),
+ RootFlags(0),
+ RootFlags(
+ deny_domain_shader_root_access |
+ deny_pixel_shader_root_access |
+ local_root_signature |
+ cbv_srv_uav_heap_directly_indexed |
+ deny_amplification_shader_root_access |
+ deny_geometry_shader_root_access |
+ deny_hull_shader_root_access |
+ deny_mesh_shader_root_access |
+ allow_stream_output |
+ sampler_heap_directly_indexed |
+ allow_input_assembler_input_layout |
+ deny_vertex_shader_root_access
+ )
+ )cc";
+
+ TrivialModuleLoader ModLoader;
+ auto PP = createPP(Source, ModLoader);
+ auto TokLoc = SourceLocation();
+
+ hlsl::RootSignatureLexer Lexer(Source, TokLoc);
+ SmallVector<RootElement> Elements;
+ hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);
+
+ // Test no diagnostics produced
+ Consumer->setNoDiag();
+
+ ASSERT_FALSE(Parser.parse());
+
+ ASSERT_EQ(Elements.size(), 3u);
+
+ RootElement Elem = Elements[0];
+ ASSERT_TRUE(std::holds_alternative<RootFlags>(Elem));
+ ASSERT_EQ(std::get<RootFlags>(Elem), RootFlags::None);
+
+ Elem = Elements[1];
+ ASSERT_TRUE(std::holds_alternative<RootFlags>(Elem));
+ ASSERT_EQ(std::get<RootFlags>(Elem), RootFlags::None);
+
+ Elem = Elements[2];
+ ASSERT_TRUE(std::holds_alternative<RootFlags>(Elem));
+ ASSERT_EQ(std::get<RootFlags>(Elem), RootFlags::ValidFlags);
+
+ ASSERT_TRUE(Consumer->isSatisfied());
+}
+
TEST_F(ParseHLSLRootSignatureTest, ValidTrailingCommaTest) {
// This test will checks we can handling trailing commas ','
const llvm::StringLiteral Source = R"cc(
@@ -496,7 +546,7 @@ TEST_F(ParseHLSLRootSignatureTest, InvalidNonZeroFlagsTest) {
hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);
// Test correct diagnostic produced
- Consumer->setExpected(diag::err_expected);
+ Consumer->setExpected(diag::err_hlsl_rootsig_non_zero_flag);
ASSERT_TRUE(Parser.parse());
ASSERT_TRUE(Consumer->isSatisfied());
diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
index 8b8324df18bb3..2ecaf69fc2f9c 100644
--- a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
+++ b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
@@ -23,6 +23,23 @@ namespace rootsig {
// Definition of the various enumerations and flags
+enum class RootFlags : uint32_t {
+ None = 0,
+ AllowInputAssemblerInputLayout = 0x1,
+ DenyVertexShaderRootAccess = 0x2,
+ DenyHullShaderRootAccess = 0x4,
+ DenyDomainShaderRootAccess = 0x8,
+ DenyGeometryShaderRootAccess = 0x10,
+ DenyPixelShaderRootAccess = 0x20,
+ AllowStreamOutput = 0x40,
+ LocalRootSignature = 0x80,
+ DenyAmplificationShaderRootAccess = 0x100,
+ DenyMeshShaderRootAccess = 0x200,
+ CBVSRVUAVHeapDirectlyIndexed = 0x400,
+ SamplerHeapDirectlyIndexed = 0x800,
+ ValidFlags = 0x00000fff
+};
+
enum class DescriptorRangeFlags : unsigned {
None = 0,
DescriptorsVolatile = 0x1,
@@ -97,8 +114,8 @@ struct DescriptorTableClause {
};
// Models RootElement : RootConstants | DescriptorTable | DescriptorTableClause
-using RootElement =
- std::variant<RootConstants, DescriptorTable, DescriptorTableClause>;
+using RootElement = std::variant<RootFlags, RootConstants, DescriptorTable,
+ DescriptorTableClause>;
} // namespace rootsig
} // namespace hlsl
``````````
</details>
https://github.com/llvm/llvm-project/pull/138055
More information about the llvm-branch-commits mailing list