[clang] fecf074 - [HLSL][RootSignature] Add parsing of DescriptorRangeFlags (#136775)
via cfe-commits
cfe-commits at lists.llvm.org
Fri Apr 25 13:05:33 PDT 2025
Author: Finn Plummer
Date: 2025-04-25T13:05:30-07:00
New Revision: fecf0742b16dc332c7a75b0a6696f08694943862
URL: https://github.com/llvm/llvm-project/commit/fecf0742b16dc332c7a75b0a6696f08694943862
DIFF: https://github.com/llvm/llvm-project/commit/fecf0742b16dc332c7a75b0a6696f08694943862.diff
LOG: [HLSL][RootSignature] Add parsing of DescriptorRangeFlags (#136775)
- Defines `parseDescriptorRangeFlags` to establish a pattern of how
flags will be parsed
- Add corresponding unit tests
Part four of implementing #126569
Added:
Modified:
clang/include/clang/Parse/ParseHLSLRootSignature.h
clang/lib/Parse/ParseHLSLRootSignature.cpp
clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
Removed:
################################################################################
diff --git a/clang/include/clang/Parse/ParseHLSLRootSignature.h b/clang/include/clang/Parse/ParseHLSLRootSignature.h
index d639ca91c002f..d2e8f4dbcfc0c 100644
--- a/clang/include/clang/Parse/ParseHLSLRootSignature.h
+++ b/clang/include/clang/Parse/ParseHLSLRootSignature.h
@@ -81,6 +81,7 @@ class RootSignatureParser {
struct ParsedClauseParams {
std::optional<llvm::hlsl::rootsig::Register> Reg;
std::optional<uint32_t> Space;
+ /// Set only when `flags = ...` is written in the clause; when unset the
+ /// clause keeps the defaults assigned by setDefaultFlags().
+ std::optional<llvm::hlsl::rootsig::DescriptorRangeFlags> Flags;
};
std::optional<ParsedClauseParams>
parseDescriptorTableClauseParams(RootSignatureToken::Kind RegType);
@@ -91,11 +92,19 @@ class RootSignatureParser {
/// Parsing methods of various enums
std::optional<llvm::hlsl::rootsig::ShaderVisibility> parseShaderVisibility();
+ /// Parse the value that follows `flags =`: either the integer literal '0'
+ /// (no flags set) or one or more descriptor range flag enums joined by '|'.
+ /// Expects CurToken to be the '=' token on entry.
+ std::optional<llvm::hlsl::rootsig::DescriptorRangeFlags>
+ parseDescriptorRangeFlags();
/// Use NumericLiteralParser to convert CurToken.NumSpelling into a unsigned
/// 32-bit integer
std::optional<uint32_t> handleUIntLiteral();
+ /// Flags may specify the value of '0' to denote that there should be no
+ /// flags set.
+ ///
+ /// Converts the current int_literal token with handleUIntLiteral and returns
+ /// true only if that yields the value 0. Returns false for any other value,
+ /// including when the literal could not be converted at all.
+ bool verifyZeroFlag();
+
/// Invoke the Lexer to consume a token and update CurToken with the result
void consumeNextToken() { CurToken = Lexer.consumeToken(); }
diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp
index 8244e91c8f89a..3b9e96017c88d 100644
--- a/clang/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -129,6 +129,7 @@ RootSignatureParser::parseDescriptorTableClause() {
ExpectedReg = TokenKind::sReg;
break;
}
+ Clause.setDefaultFlags();
auto Params = parseDescriptorTableClauseParams(ExpectedReg);
if (!Params.has_value())
@@ -147,6 +148,9 @@ RootSignatureParser::parseDescriptorTableClause() {
if (Params->Space.has_value())
Clause.Space = Params->Space.value();
+ if (Params->Flags.has_value())
+ Clause.Flags = Params->Flags.value();
+
if (consumeExpectedToken(TokenKind::pu_r_paren,
diag::err_hlsl_unexpected_end_of_params,
/*param of=*/ParamKind))
@@ -194,6 +198,24 @@ RootSignatureParser::parseDescriptorTableClauseParams(TokenKind RegType) {
return std::nullopt;
Params.Space = Space;
}
+
+ // `flags` `=` DESCRIPTOR_RANGE_FLAGS
+ if (tryConsumeExpectedToken(TokenKind::kw_flags)) {
+ if (Params.Flags.has_value()) {
+ getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_repeat_param)
+ << CurToken.TokKind;
+ return std::nullopt;
+ }
+
+ if (consumeExpectedToken(TokenKind::pu_equal))
+ return std::nullopt;
+
+ auto Flags = parseDescriptorRangeFlags();
+ if (!Flags.has_value())
+ return std::nullopt;
+ Params.Flags = Flags;
+ }
+
} while (tryConsumeExpectedToken(TokenKind::pu_comma));
return Params;
@@ -268,6 +290,54 @@ RootSignatureParser::parseShaderVisibility() {
return std::nullopt;
}
+/// Accumulate \p Flag into \p Flags.
+///
+/// When nothing has been accumulated yet (\p Flags is empty) the result is
+/// just \p Flag; otherwise it is the bitwise-or of the accumulated value and
+/// \p Flag, computed on the enum's underlying integer type.
+template <typename FlagType>
+static FlagType maybeOrFlag(std::optional<FlagType> Flags, FlagType Flag) {
+  if (Flags.has_value()) {
+    auto Accumulated = llvm::to_underlying(*Flags);
+    auto Added = llvm::to_underlying(Flag);
+    return static_cast<FlagType>(Accumulated | Added);
+  }
+  return Flag;
+}
+
+/// Parse the right-hand side of a `flags =` parameter.
+///
+///   DESCRIPTOR_RANGE_FLAGS : '0' | FLAG ( '|' FLAG )*
+///
+/// Expects CurToken to be the `=` preceding the flags. Returns
+/// DescriptorRangeFlags::None for '0', the bitwise-or of every flag listed
+/// otherwise, or std::nullopt after reporting a diagnostic.
+std::optional<llvm::hlsl::rootsig::DescriptorRangeFlags>
+RootSignatureParser::parseDescriptorRangeFlags() {
+  assert(CurToken.TokKind == TokenKind::pu_equal &&
+         "Expects to only be invoked starting at given keyword");
+
+  // Handle the edge-case of '0' to specify no flags set
+  if (tryConsumeExpectedToken(TokenKind::int_literal)) {
+    if (!verifyZeroFlag()) {
+      getDiags().Report(CurToken.TokLoc, diag::err_expected) << "'0'";
+      return std::nullopt;
+    }
+    return DescriptorRangeFlags::None;
+  }
+
+  TokenKind Expected[] = {
+#define DESCRIPTOR_RANGE_FLAG_ENUM(NAME, LIT, ON) TokenKind::en_##NAME,
+#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
+  };
+
+  std::optional<DescriptorRangeFlags> Flags;
+
+  do {
+    // A flag is mandatory both directly after `=` and after every `|`.
+    // Without this check `flags = )` failed without emitting any diagnostic,
+    // and a dangling `|` (e.g. `flags = DATA_STATIC | )`) was silently
+    // accepted.
+    if (!tryConsumeExpectedToken(Expected)) {
+      getDiags().Report(CurToken.TokLoc, diag::err_expected)
+          << "a descriptor range flag";
+      return std::nullopt;
+    }
+
+    switch (CurToken.TokKind) {
+#define DESCRIPTOR_RANGE_FLAG_ENUM(NAME, LIT, ON)                              \
+  case TokenKind::en_##NAME:                                                   \
+    Flags =                                                                    \
+        maybeOrFlag<DescriptorRangeFlags>(Flags, DescriptorRangeFlags::NAME);  \
+    break;
+#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
+    default:
+      llvm_unreachable("Switch for consumed enum token was not provided");
+    }
+  } while (tryConsumeExpectedToken(TokenKind::pu_or));
+
+  return Flags;
+}
+
std::optional<uint32_t> RootSignatureParser::handleUIntLiteral() {
// Parse the numeric value and do semantic checks on its specification
clang::NumericLiteralParser Literal(CurToken.NumSpelling, CurToken.TokLoc,
@@ -290,6 +360,12 @@ std::optional<uint32_t> RootSignatureParser::handleUIntLiteral() {
return Val.getExtValue();
}
+/// Report whether the current int_literal token denotes exactly zero.
+///
+/// The literal is converted with handleUIntLiteral; a failed conversion
+/// (empty optional) is treated as "not zero".
+bool RootSignatureParser::verifyZeroFlag() {
+  assert(CurToken.TokKind == TokenKind::int_literal);
+  std::optional<uint32_t> Literal = handleUIntLiteral();
+  // Comparing an empty optional against a value yields false.
+  return Literal == 0u;
+}
+
bool RootSignatureParser::peekExpectedToken(TokenKind Expected) {
return peekExpectedToken(ArrayRef{Expected});
}
diff --git a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
index 1d89567509e72..f4baf1580de61 100644
--- a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
+++ b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
@@ -130,10 +130,14 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
const llvm::StringLiteral Source = R"cc(
DescriptorTable(
CBV(b0),
- SRV(space = 3, t42),
+ SRV(space = 3, t42, flags = 0),
visibility = SHADER_VISIBILITY_PIXEL,
Sampler(s987, space = +2),
- UAV(u4294967294)
+ UAV(u4294967294,
+ flags = Descriptors_Volatile | Data_Volatile
+ | Data_Static_While_Set_At_Execute | Data_Static
+ | Descriptors_Static_Keeping_Buffer_Bounds_Checks
+ )
),
DescriptorTable()
)cc";
@@ -159,6 +163,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
RegisterType::BReg);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Reg.Number, 0u);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 0u);
+ ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags,
+ DescriptorRangeFlags::DataStaticWhileSetAtExecute);
Elem = Elements[1];
ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
@@ -167,6 +173,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
RegisterType::TReg);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Reg.Number, 42u);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 3u);
+ ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags,
+ DescriptorRangeFlags::None);
Elem = Elements[2];
ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
@@ -175,6 +183,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
RegisterType::SReg);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Reg.Number, 987u);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 2u);
+ ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags,
+ DescriptorRangeFlags::None);
Elem = Elements[3];
ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
@@ -183,6 +193,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
RegisterType::UReg);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Reg.Number, 4294967294u);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 0u);
+ ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags,
+ DescriptorRangeFlags::ValidFlags);
Elem = Elements[4];
ASSERT_TRUE(std::holds_alternative<DescriptorTable>(Elem));
@@ -199,6 +211,35 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
ASSERT_TRUE(Consumer->isSatisfied());
}
+TEST_F(ParseHLSLRootSignatureTest, ValidSamplerFlagsTest) {
+  // This test checks that a Sampler clause accepts the flag that is valid for
+  // Sampler descriptor ranges
+  const llvm::StringLiteral Source = R"cc(
+    DescriptorTable(Sampler(s0, flags = DESCRIPTORS_VOLATILE))
+  )cc";
+
+  TrivialModuleLoader ModLoader;
+  auto PP = createPP(Source, ModLoader);
+  auto TokLoc = SourceLocation();
+
+  hlsl::RootSignatureLexer Lexer(Source, TokLoc);
+  SmallVector<RootElement> Elements;
+  hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);
+
+  // The parse must succeed without emitting any diagnostic
+  Consumer->setNoDiag();
+  bool ParseFailed = Parser.parse();
+  ASSERT_FALSE(ParseFailed);
+
+  const RootElement &Elem = Elements[0];
+  ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
+  const auto &Clause = std::get<DescriptorTableClause>(Elem);
+  ASSERT_EQ(Clause.Type, ClauseType::Sampler);
+  ASSERT_EQ(Clause.Flags, DescriptorRangeFlags::ValidSamplerFlags);
+
+  ASSERT_TRUE(Consumer->isSatisfied());
+}
+
TEST_F(ParseHLSLRootSignatureTest, ValidTrailingCommaTest) {
// This test will checks we can handling trailing commas ','
const llvm::StringLiteral Source = R"cc(
@@ -383,4 +424,28 @@ TEST_F(ParseHLSLRootSignatureTest, InvalidLexOverflowedNumberTest) {
ASSERT_TRUE(Consumer->isSatisfied());
}
+TEST_F(ParseHLSLRootSignatureTest, InvalidNonZeroFlagsTest) {
+  // This test checks that an integer literal other than '0' is rejected as
+  // the value of flags
+  const llvm::StringLiteral Source = R"cc(
+    DescriptorTable(
+      CBV(b0, flags = 3)
+    )
+  )cc";
+
+  TrivialModuleLoader ModLoader;
+  auto PP = createPP(Source, ModLoader);
+  auto TokLoc = SourceLocation();
+
+  hlsl::RootSignatureLexer Lexer(Source, TokLoc);
+  SmallVector<RootElement> Elements;
+  hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);
+
+  // Parsing must fail and report err_expected
+  Consumer->setExpected(diag::err_expected);
+  bool ParseFailed = Parser.parse();
+  ASSERT_TRUE(ParseFailed);
+
+  ASSERT_TRUE(Consumer->isSatisfied());
+}
+
} // anonymous namespace
diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
index d51b853942dd3..0745bce983bb3 100644
--- a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
+++ b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
@@ -23,6 +23,17 @@ namespace rootsig {
// Definition of the various enumerations and flags
+/// Flags that may be attached to a descriptor range. Each named flag is a
+/// distinct bit, so values can be combined with a bitwise-or.
+enum class DescriptorRangeFlags : unsigned {
+  None = 0,
+  DescriptorsVolatile = 0x1,
+  DataVolatile = 0x2,
+  DataStaticWhileSetAtExecute = 0x4,
+  DataStatic = 0x8,
+  DescriptorsStaticKeepingBufferBoundsChecks = 0x10000,
+  // Union of every individual flag above (== 0x1000f). Inside the enum body
+  // the enumerators still have the underlying type, so '|' is well-formed.
+  ValidFlags = DescriptorsVolatile | DataVolatile |
+               DataStaticWhileSetAtExecute | DataStatic |
+               DescriptorsStaticKeepingBufferBoundsChecks,
+  // Subset of flags permitted on a Sampler range.
+  ValidSamplerFlags = DescriptorsVolatile,
+};
+
enum class ShaderVisibility {
All = 0,
Vertex = 1,
@@ -55,6 +66,22 @@ struct DescriptorTableClause {
ClauseType Type;
Register Reg;
uint32_t Space = 0;
+ DescriptorRangeFlags Flags;
+
+  /// Assign the default descriptor range flags for this clause's Type:
+  /// DataStaticWhileSetAtExecute for CBV and SRV ranges, DataVolatile for
+  /// UAV ranges and None for Sampler ranges. Flags written explicitly in the
+  /// clause are expected to overwrite this default afterwards.
+  void setDefaultFlags() {
+    switch (Type) {
+    case ClauseType::CBuffer:
+    case ClauseType::SRV:
+      Flags = DescriptorRangeFlags::DataStaticWhileSetAtExecute;
+      return;
+    case ClauseType::UAV:
+      Flags = DescriptorRangeFlags::DataVolatile;
+      return;
+    case ClauseType::Sampler:
+      Flags = DescriptorRangeFlags::None;
+      return;
+    }
+  }
};
// Models RootElement : DescriptorTable | DescriptorTableClause
More information about the cfe-commits
mailing list