[llvm-branch-commits] [clang] [llvm] [HLSL][RootSignature] Add infrastructure to parse parameters (PR #133800)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Mon Mar 31 14:22:01 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-hlsl
Author: Finn Plummer (inbelic)
<details>
<summary>Changes</summary>
- defines `ParamType` as a way to represent a reference to some
parameter in a root signature
- defines `parseParam` and `parseParams` as the infrastructure to define
how the parameters of a given struct should be parsed in an orderless
manner
- implements parsing of two param types: `UInt32` and `Register` to
demonstrate the parsing implementation and allow for unit testing
Part two of implementing: https://github.com/llvm/llvm-project/issues/126569
---
Full diff: https://github.com/llvm/llvm-project/pull/133800.diff
5 Files Affected:
- (modified) clang/include/clang/Basic/DiagnosticParseKinds.td (+4-1)
- (modified) clang/include/clang/Parse/ParseHLSLRootSignature.h (+40)
- (modified) clang/lib/Parse/ParseHLSLRootSignature.cpp (+151-14)
- (modified) clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp (+142-4)
- (modified) llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h (+15)
``````````diff
diff --git a/clang/include/clang/Basic/DiagnosticParseKinds.td b/clang/include/clang/Basic/DiagnosticParseKinds.td
index 2582e1e5ef0f6..ab12159ba5ae1 100644
--- a/clang/include/clang/Basic/DiagnosticParseKinds.td
+++ b/clang/include/clang/Basic/DiagnosticParseKinds.td
@@ -1830,8 +1830,11 @@ def err_hlsl_virtual_function
def err_hlsl_virtual_inheritance
: Error<"virtual inheritance is unsupported in HLSL">;
-// HLSL Root Siganture diagnostic messages
+// HLSL Root Signature Parser Diagnostics
def err_hlsl_unexpected_end_of_params
: Error<"expected %0 to denote end of parameters, or, another valid parameter of %1">;
+def err_hlsl_rootsig_repeat_param : Error<"specified the same parameter '%0' multiple times">;
+def err_hlsl_rootsig_missing_param : Error<"did not specify mandatory parameter '%0'">;
+// %0 selects the integer kind: 0 = signed, 1 = unsigned.
+def err_hlsl_number_literal_overflow : Error<"integer literal is too large to be represented as a 32-bit %select{signed|unsigned}0 integer type">;
} // end of Parser diagnostics
diff --git a/clang/include/clang/Parse/ParseHLSLRootSignature.h b/clang/include/clang/Parse/ParseHLSLRootSignature.h
index 43b41315b88b5..02e99e83875db 100644
--- a/clang/include/clang/Parse/ParseHLSLRootSignature.h
+++ b/clang/include/clang/Parse/ParseHLSLRootSignature.h
@@ -69,6 +69,46 @@ class RootSignatureParser {
bool parseDescriptorTable();
bool parseDescriptorTableClause();
+  /// Each unique ParamType has a custom parse method defined that is used to
+  /// parse the parameter it refers to.
+  ///
+  /// This function switches on the ParamType using std::visit and dispatches
+  /// onto the corresponding parse method.
+  ///
+  /// \returns true if an error was reported.
+  bool parseParam(llvm::hlsl::rootsig::ParamType Ref);
+
+  /// Parameter arguments (eg. `bReg`, `space`, ...) can be specified in any
+  /// order, exactly once, and only a subset are mandatory. This function acts
+  /// as the infrastructure to do so in a declarative way.
+  ///
+  /// For the example:
+  ///  SmallDenseMap<TokenKind, ParamType> Params = {
+  ///    {TokenKind::bReg, &Clause.Register},
+  ///    {TokenKind::kw_space, &Clause.Space},
+  ///  };
+  ///  SmallDenseSet<TokenKind> Mandatory = {
+  ///    TokenKind::bReg
+  ///  };
+  ///
+  /// We can read it as:
+  ///
+  /// when 'b0' is encountered, invoke the parse method for the type
+  ///   of &Clause.Register (Register *) and update the parameter
+  /// when 'space' is encountered, invoke a parse method for the type
+  ///   of &Clause.Space (uint32_t *) and update the parameter
+  ///
+  /// and 'bReg' must be specified
+  ///
+  /// \returns true if an error was reported: a repeated parameter, a failure
+  /// while parsing a parameter value, or a missing mandatory parameter.
+  bool parseParams(
+      llvm::SmallDenseMap<TokenKind, llvm::hlsl::rootsig::ParamType> &Params,
+      llvm::SmallDenseSet<TokenKind> &Mandatory);
+
+  /// Parameter parse methods corresponding to a ParamType. Each returns true
+  /// if an error was reported.
+  bool parseUIntParam(uint32_t *X);
+  bool parseRegister(llvm::hlsl::rootsig::Register *Reg);
+
+  /// Use NumericLiteralParser to convert CurToken.NumSpelling into an unsigned
+  /// 32-bit integer. Returns true if an error (including overflow) was
+  /// reported.
+  bool handleUIntLiteral(uint32_t *X);
+
/// Invoke the Lexer to consume a token and update CurToken with the result
void consumeNextToken() { CurToken = Lexer.ConsumeToken(); }
diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp
index 33caca5fa1c82..62d29baea49d3 100644
--- a/clang/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -8,6 +8,8 @@
#include "clang/Parse/ParseHLSLRootSignature.h"
+#include "clang/Lex/LiteralSupport.h"
+
#include "llvm/Support/raw_ostream.h"
using namespace llvm::hlsl::rootsig;
@@ -39,12 +41,11 @@ bool RootSignatureParser::parse() {
break;
}
- if (!tryConsumeExpectedToken(TokenKind::end_of_stream)) {
- getDiags().Report(CurToken.TokLoc, diag::err_hlsl_unexpected_end_of_params)
- << /*expected=*/TokenKind::end_of_stream
- << /*param of=*/TokenKind::kw_RootSignature;
+ if (consumeExpectedToken(TokenKind::end_of_stream,
+ diag::err_hlsl_unexpected_end_of_params,
+ /*param of=*/TokenKind::kw_RootSignature))
return true;
- }
+
return false;
}
@@ -70,12 +71,10 @@ bool RootSignatureParser::parseDescriptorTable() {
break;
}
- if (!tryConsumeExpectedToken(TokenKind::pu_r_paren)) {
- getDiags().Report(CurToken.TokLoc, diag::err_hlsl_unexpected_end_of_params)
- << /*expected=*/TokenKind::pu_r_paren
- << /*param of=*/TokenKind::kw_DescriptorTable;
+ if (consumeExpectedToken(TokenKind::pu_r_paren,
+ diag::err_hlsl_unexpected_end_of_params,
+ /*param of=*/TokenKind::kw_DescriptorTable))
return true;
- }
Elements.push_back(Table);
return false;
@@ -87,37 +86,174 @@ bool RootSignatureParser::parseDescriptorTableClause() {
CurToken.Kind == TokenKind::kw_UAV ||
CurToken.Kind == TokenKind::kw_Sampler) &&
"Expects to only be invoked starting at given keyword");
+ TokenKind ParamKind = CurToken.Kind; // retain for diagnostics
DescriptorTableClause Clause;
- switch (CurToken.Kind) {
+ TokenKind ExpectedRegister;
+ switch (ParamKind) {
default:
llvm_unreachable("Switch for consumed token was not provided");
case TokenKind::kw_CBV:
Clause.Type = ClauseType::CBuffer;
+ ExpectedRegister = TokenKind::bReg;
break;
case TokenKind::kw_SRV:
Clause.Type = ClauseType::SRV;
+ ExpectedRegister = TokenKind::tReg;
break;
case TokenKind::kw_UAV:
Clause.Type = ClauseType::UAV;
+ ExpectedRegister = TokenKind::uReg;
break;
case TokenKind::kw_Sampler:
Clause.Type = ClauseType::Sampler;
+ ExpectedRegister = TokenKind::sReg;
break;
}
if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
- CurToken.Kind))
+ ParamKind))
return true;
- if (consumeExpectedToken(TokenKind::pu_r_paren, diag::err_expected_after,
- CurToken.Kind))
+ llvm::SmallDenseMap<TokenKind, ParamType> Params = {
+ {ExpectedRegister, &Clause.Register},
+ {TokenKind::kw_space, &Clause.Space},
+ };
+ llvm::SmallDenseSet<TokenKind> Mandatory = {
+ ExpectedRegister,
+ };
+
+ if (parseParams(Params, Mandatory))
+ return true;
+
+ if (consumeExpectedToken(TokenKind::pu_r_paren,
+ diag::err_hlsl_unexpected_end_of_params,
+ /*param of=*/ParamKind))
return true;
Elements.push_back(Clause);
return false;
}
+namespace {
+// Helper struct defined to use the overloaded notation of std::visit: it
+// inherits the call operator of every lambda passed to it, so std::visit can
+// pick the overload matching the active alternative. Kept in an anonymous
+// namespace because it is local to this translation unit.
+template <class... Ts> struct ParseMethods : Ts... { using Ts::operator()...; };
+template <class... Ts> ParseMethods(Ts...) -> ParseMethods<Ts...>;
+} // namespace
+
+/// Dispatch on the alternative held by \p Ref to its parse method:
+///  - `Register *`: parse the register token currently held in CurToken.
+///  - `uint32_t *`: expect `=` after the parameter keyword, then parse an
+///    unsigned integer value.
+///
+/// \returns true if an error was reported.
+bool RootSignatureParser::parseParam(ParamType Ref) {
+  return std::visit(
+      ParseMethods{
+          [this](Register *X) -> bool { return parseRegister(X); },
+          [this](uint32_t *X) -> bool {
+            return consumeExpectedToken(TokenKind::pu_equal,
+                                        diag::err_expected_after,
+                                        CurToken.Kind) ||
+                   parseUIntParam(X);
+          },
+      },
+      Ref);
+}
+
+/// Parse a comma separated list of parameters, in any order, until a token is
+/// reached that is not one of the keys of \p Params. A trailing comma is
+/// accepted.
+///
+/// \param Params maps each accepted parameter token to the field it writes.
+/// \param Mandatory the parameter tokens that must each appear.
+/// \returns true if an error was reported: a repeated parameter, a failure
+///          while parsing a parameter value, or a missing mandatory parameter
+///          (every missing one is reported before returning).
+bool RootSignatureParser::parseParams(
+    llvm::SmallDenseMap<TokenKind, ParamType> &Params,
+    llvm::SmallDenseSet<TokenKind> &Mandatory) {
+  // Collect the keywords that may begin a parameter.
+  SmallVector<TokenKind> Keywords;
+  Keywords.reserve(Params.size());
+  for (const auto &Pair : Params)
+    Keywords.push_back(Pair.first);
+
+  // Keep track of which keywords have been seen to report duplicates
+  llvm::SmallDenseSet<TokenKind> Seen;
+
+  while (tryConsumeExpectedToken(Keywords)) {
+    // insert() reports whether the keyword is new, avoiding a second lookup.
+    if (!Seen.insert(CurToken.Kind).second) {
+      getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_repeat_param)
+          << CurToken.Kind;
+      return true;
+    }
+
+    // The consumed token came from Keywords, so it is always a key of Params;
+    // find() is used instead of operator[] so the lookup can never insert.
+    auto It = Params.find(CurToken.Kind);
+    assert(It != Params.end() && "Consumed keyword must be a known parameter");
+    if (parseParam(It->second))
+      return true;
+
+    // No comma ends the list; a comma followed by a non-keyword (e.g. ')')
+    // ends it on the next loop condition, which permits a trailing comma.
+    if (!tryConsumeExpectedToken(TokenKind::pu_comma))
+      break;
+  }
+
+  bool AllMandatoryDefined = true;
+  for (TokenKind Kind : Mandatory) {
+    if (Seen.contains(Kind))
+      continue;
+    getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_missing_param)
+        << Kind;
+    AllMandatoryDefined = false;
+  }
+
+  return !AllMandatoryDefined;
+}
+
+/// Parse the value of an unsigned integer parameter. The caller has already
+/// consumed the `=`, which is CurToken on entry. Accepts an optional unary `+`
+/// followed by an integer literal and stores the result in \p X.
+///
+/// \returns true if an error was reported.
+bool RootSignatureParser::parseUIntParam(uint32_t *X) {
+  assert(CurToken.Kind == TokenKind::pu_equal &&
+         "Expects to only be invoked starting at given keyword");
+  // A leading '+' is accepted and ignored (e.g. `space = +2`).
+  tryConsumeExpectedToken(TokenKind::pu_plus);
+
+  const bool MissingLiteral = consumeExpectedToken(
+      TokenKind::int_literal, diag::err_expected_after, CurToken.Kind);
+  if (MissingLiteral)
+    return true;
+
+  return handleUIntLiteral(X);
+}
+
+/// Record the register type of the register token in CurToken (bReg, tReg,
+/// uReg or sReg) into \p Reg, then convert the token's numeric part into
+/// Reg->Number.
+///
+/// \returns true if an error was reported while converting the number.
+bool RootSignatureParser::parseRegister(Register *Reg) {
+  assert(
+      (CurToken.Kind == TokenKind::bReg || CurToken.Kind == TokenKind::tReg ||
+       CurToken.Kind == TokenKind::uReg || CurToken.Kind == TokenKind::sReg) &&
+      "Expects to only be invoked starting at given keyword");
+
+  switch (CurToken.Kind) {
+  case TokenKind::bReg:
+    Reg->ViewType = RegisterType::BReg;
+    break;
+  case TokenKind::tReg:
+    Reg->ViewType = RegisterType::TReg;
+    break;
+  case TokenKind::uReg:
+    Reg->ViewType = RegisterType::UReg;
+    break;
+  case TokenKind::sReg:
+    Reg->ViewType = RegisterType::SReg;
+    break;
+  default:
+    llvm_unreachable("Register token kind is checked by the assert above");
+  }
+
+  // Any conversion error (e.g. overflow) has already been reported by
+  // handleUIntLiteral; just propagate it.
+  return handleUIntLiteral(&Reg->Number);
+}
+
+/// Convert CurToken.NumSpelling into an unsigned 32-bit integer and store it
+/// in \p X.
+///
+/// \returns true if an error was reported, either by NumericLiteralParser or
+///          because the value does not fit in 32 bits.
+bool RootSignatureParser::handleUIntLiteral(uint32_t *X) {
+  // Parse the numeric value and do semantic checks on its specification
+  clang::NumericLiteralParser Literal(CurToken.NumSpelling, CurToken.TokLoc,
+                                      PP.getSourceManager(), PP.getLangOpts(),
+                                      PP.getTargetInfo(), PP.getDiagnostics());
+  if (Literal.hadError)
+    return true; // Error has already been reported so just return
+
+  assert(Literal.isIntegerLiteral() && "IsNumberChar will only support digits");
+
+  // The value is unsigned, so the APSInt must be unsigned too; otherwise
+  // getExtValue() would sign-extend values >= 2^31. GetIntegerValue uses
+  // Val's 32-bit width and returns true on overflow.
+  llvm::APSInt Val(32, /*isUnsigned=*/true);
+  if (Literal.GetIntegerValue(Val)) {
+    // Report that the value has overflowed. Index 0 of the diagnostic's
+    // %select is the "signed" wording, so pass 1 for this unsigned value.
+    PP.getDiagnostics().Report(CurToken.TokLoc,
+                               diag::err_hlsl_number_literal_overflow)
+        << /*unsigned=*/1;
+    return true;
+  }
+
+  *X = static_cast<uint32_t>(Val.getZExtValue());
+  return false;
+}
+
// Single-token convenience overload: forwards to the ArrayRef overload with a
// one-element list containing Expected.
bool RootSignatureParser::peekExpectedToken(TokenKind Expected) {
return peekExpectedToken(ArrayRef{Expected});
}
@@ -139,6 +275,7 @@ bool RootSignatureParser::consumeExpectedToken(TokenKind Expected,
case diag::err_expected:
DB << Expected;
break;
+ case diag::err_hlsl_unexpected_end_of_params:
case diag::err_expected_either:
case diag::err_expected_after:
DB << Expected << Context;
diff --git a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
index acdf455a5d6aa..5b162e9a1b4cd 100644
--- a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
+++ b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
@@ -130,10 +130,10 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseEmptyTest) {
TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
const llvm::StringLiteral Source = R"cc(
DescriptorTable(
- CBV(),
- SRV(),
- Sampler(),
- UAV()
+ CBV(b0),
+ SRV(space = 3, t42),
+ Sampler(s987, space = +2),
+ UAV(u4294967294)
),
DescriptorTable()
)cc";
@@ -155,18 +155,34 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
RootElement Elem = Elements[0];
ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::CBuffer);
+ ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.ViewType,
+ RegisterType::BReg);
+ ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.Number, 0u);
+ ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 0u);
Elem = Elements[1];
ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::SRV);
+ ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.ViewType,
+ RegisterType::TReg);
+ ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.Number, 42u);
+ ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 3u);
Elem = Elements[2];
ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::Sampler);
+ ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.ViewType,
+ RegisterType::SReg);
+ ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.Number, 987u);
+ ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 2u);
Elem = Elements[3];
ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::UAV);
+ ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.ViewType,
+ RegisterType::UReg);
+ ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.Number, 4294967294u);
+ ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 0u);
Elem = Elements[4];
ASSERT_TRUE(std::holds_alternative<DescriptorTable>(Elem));
@@ -176,6 +192,32 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
Elem = Elements[5];
ASSERT_TRUE(std::holds_alternative<DescriptorTable>(Elem));
ASSERT_EQ(std::get<DescriptorTable>(Elem).NumClauses, 0u);
+
+ ASSERT_TRUE(Consumer->isSatisfied());
+}
+
+TEST_F(ParseHLSLRootSignatureTest, ValidTrailingCommaTest) {
+ // This test checks that we can handle trailing commas ',', both after the
+ // last parameter of a clause and after the last clause of a table
+ const llvm::StringLiteral Source = R"cc(
+ DescriptorTable(
+ CBV(b0, ),
+ SRV(t42),
+ )
+ )cc";
+
+ TrivialModuleLoader ModLoader;
+ auto PP = createPP(Source, ModLoader);
+ auto TokLoc = SourceLocation();
+
+ hlsl::RootSignatureLexer Lexer(Source, TokLoc);
+ SmallVector<RootElement> Elements;
+ hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);
+
+ // Test no diagnostics produced
+ Consumer->setNoDiag();
+
+ ASSERT_FALSE(Parser.parse());
+
ASSERT_TRUE(Consumer->isSatisfied());
}
@@ -237,6 +279,102 @@ TEST_F(ParseHLSLRootSignatureTest, InvalidParseUnexpectedEndOfStreamTest) {
// Test correct diagnostic produced - end of stream
Consumer->setExpected(diag::err_expected_after);
+
+ ASSERT_TRUE(Parser.parse());
+
+ ASSERT_TRUE(Consumer->isSatisfied());
+}
+
+TEST_F(ParseHLSLRootSignatureTest, InvalidMissingParameterTest) {
+ // This test will check that the parsing fails due to a mandatory
+ // parameter (register) not being specified
+ const llvm::StringLiteral Source = R"cc(
+ DescriptorTable(
+ CBV()
+ )
+ )cc";
+
+ TrivialModuleLoader ModLoader;
+ auto PP = createPP(Source, ModLoader);
+ auto TokLoc = SourceLocation();
+
+ hlsl::RootSignatureLexer Lexer(Source, TokLoc);
+ SmallVector<RootElement> Elements;
+ hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);
+
+ // Test correct diagnostic produced
+ Consumer->setExpected(diag::err_hlsl_rootsig_missing_param);
+ ASSERT_TRUE(Parser.parse());
+
+ ASSERT_TRUE(Consumer->isSatisfied());
+}
+
+TEST_F(ParseHLSLRootSignatureTest, InvalidRepeatedMandatoryParameterTest) {
+ // This test will check that the parsing fails due to the same mandatory
+ // parameter being specified multiple times
+ const llvm::StringLiteral Source = R"cc(
+ DescriptorTable(
+ CBV(b32, b84)
+ )
+ )cc";
+
+ TrivialModuleLoader ModLoader;
+ auto PP = createPP(Source, ModLoader);
+ auto TokLoc = SourceLocation();
+
+ hlsl::RootSignatureLexer Lexer(Source, TokLoc);
+ SmallVector<RootElement> Elements;
+ hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);
+
+ // Test correct diagnostic produced
+ Consumer->setExpected(diag::err_hlsl_rootsig_repeat_param);
+ ASSERT_TRUE(Parser.parse());
+
+ ASSERT_TRUE(Consumer->isSatisfied());
+}
+
+TEST_F(ParseHLSLRootSignatureTest, InvalidRepeatedOptionalParameterTest) {
+ // This test will check that the parsing fails due to the same optional
+ // parameter being specified multiple times
+ const llvm::StringLiteral Source = R"cc(
+ DescriptorTable(
+ CBV(space = 2, space = 0)
+ )
+ )cc";
+
+ TrivialModuleLoader ModLoader;
+ auto PP = createPP(Source, ModLoader);
+ auto TokLoc = SourceLocation();
+
+ hlsl::RootSignatureLexer Lexer(Source, TokLoc);
+ SmallVector<RootElement> Elements;
+ hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);
+
+ // Test correct diagnostic produced
+ Consumer->setExpected(diag::err_hlsl_rootsig_repeat_param);
+ ASSERT_TRUE(Parser.parse());
+
+ ASSERT_TRUE(Consumer->isSatisfied());
+}
+
+TEST_F(ParseHLSLRootSignatureTest, InvalidLexOverflowedNumberTest) {
+ // This test will check that the lexing fails due to an integer overflow
+ const llvm::StringLiteral Source = R"cc(
+ DescriptorTable(
+ CBV(b4294967296)
+ )
+ )cc";
+
+ TrivialModuleLoader ModLoader;
+ auto PP = createPP(Source, ModLoader);
+ auto TokLoc = SourceLocation();
+
+ hlsl::RootSignatureLexer Lexer(Source, TokLoc);
+ SmallVector<RootElement> Elements;
+ hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);
+
+ // Test correct diagnostic produced
+ Consumer->setExpected(diag::err_hlsl_number_literal_overflow);
ASSERT_TRUE(Parser.parse());
ASSERT_TRUE(Consumer->isSatisfied());
diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
index c1b67844c747f..0ae8879b6c7d5 100644
--- a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
+++ b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
@@ -23,6 +23,13 @@ namespace rootsig {
// Definitions of the in-memory data layout structures
+// Models the different registers: bReg | tReg | uReg | sReg
+enum class RegisterType { BReg, TReg, UReg, SReg };
+
+// A register reference such as `b0` or `t42`: the register type given by its
+// prefix letter and the register number that follows it. The defaults only
+// give a default-constructed Register determinate values; the parser
+// overwrites both fields whenever a register is parsed.
+struct Register {
+  RegisterType ViewType = RegisterType::BReg;
+  uint32_t Number = 0;
+};
+
// Models the end of a descriptor table and stores its visibility
struct DescriptorTable {
uint32_t NumClauses = 0; // The number of clauses in the table
@@ -32,11 +39,19 @@ struct DescriptorTable {
using ClauseType = llvm::dxil::ResourceClass;
struct DescriptorTableClause {
ClauseType Type;
+  // The member's type is spelled qualified so that declaring a member named
+  // `Register` does not change the meaning of the unqualified name `Register`
+  // inside this class (GCC diagnoses that under -Wchanges-meaning).
+  rootsig::Register Register;
+  uint32_t Space = 0;
};
// Models RootElement : DescriptorTable | DescriptorTableClause
using RootElement = std::variant<DescriptorTable, DescriptorTableClause>;
+// ParamType is used as an 'any' type that refers to a parameter field of a
+// RootElement (e.g. a Register or a uint32_t inside a DescriptorTableClause).
+// Each alternative of ParamType is expected to have a parse method defined in
+// the root signature parser, which is dispatched on when we are attempting to
+// parse that parameter.
+using ParamType = std::variant<uint32_t *, Register *>;
+
} // namespace rootsig
} // namespace hlsl
} // namespace llvm
``````````
</details>
https://github.com/llvm/llvm-project/pull/133800
More information about the llvm-branch-commits
mailing list