[clang] [llvm] [HLSL][RootSignature] Implement Parsing of Descriptor Tables (PR #122982)
Finn Plummer via cfe-commits
cfe-commits at lists.llvm.org
Wed Mar 5 18:56:17 PST 2025
================
@@ -0,0 +1,417 @@
+#include "clang/Parse/ParseHLSLRootSignature.h"
+
+#include "clang/Lex/LiteralSupport.h"
+
+#include "llvm/Support/raw_ostream.h"
+
+using namespace llvm::hlsl::rootsig;
+
+namespace clang {
+namespace hlsl {
+
+// Render the spellings of the given token kinds as a ", "-separated list,
+// for use in diagnostics. Spellings come from the TOK entries of
+// HLSLRootSignatureTokenKinds.def.
+static std::string FormatTokenKinds(ArrayRef<TokenKind> Kinds) {
+  std::string Result;
+  llvm::raw_string_ostream Stream(Result);
+  for (size_t Idx = 0, End = Kinds.size(); Idx != End; ++Idx) {
+    // Separator goes before every entry except the first.
+    if (Idx != 0)
+      Stream << ", ";
+    switch (Kinds[Idx]) {
+#define TOK(X, SPELLING) \
+  case TokenKind::X: \
+    Stream << SPELLING; \
+    break;
+#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
+    }
+  }
+  return Result;
+}
+
+// Parser Definitions
+
+RootSignatureParser::RootSignatureParser(SmallVector<RootElement> &Elements,
+ RootSignatureLexer &Lexer,
+ Preprocessor &PP)
+ : Elements(Elements), Lexer(Lexer), PP(PP), CurToken(SourceLocation()) {}
+
+// Parse the whole root signature: a comma-separated list of root elements.
+// Returns true if an error was reported, false on success.
+bool RootSignatureParser::Parse() {
+  // An empty RootSignature() is accepted and yields no elements.
+  if (Lexer.EndOfBuffer())
+    return false;
+
+  for (;;) {
+    if (ParseRootElement())
+      return true;
+    // Reaching the end of the buffer after an element means we are done.
+    if (Lexer.EndOfBuffer())
+      return false;
+    // Otherwise the next element must be introduced by a comma.
+    if (ConsumeExpectedToken(TokenKind::pu_comma))
+      return true;
+  }
+}
+
+// Parse a single root element, dispatching on its leading keyword.
+// Only DescriptorTable is accepted at present. Returns true on error.
+bool RootSignatureParser::ParseRootElement() {
+  if (ConsumeExpectedToken(TokenKind::kw_DescriptorTable))
+    return true;
+
+  if (CurToken.Kind == TokenKind::kw_DescriptorTable)
+    return ParseDescriptorTable();
+
+  // ConsumeExpectedToken only succeeds for the kinds handled above.
+  llvm_unreachable("Switch for an expected token was not provided");
+}
+
+// Parse `DescriptorTable( [entry {, entry}] )` after the keyword has been
+// consumed. Each entry is either `visibility = ...` (at most once) or a
+// descriptor table clause. Appends the table to Elements; returns true on
+// error.
+bool RootSignatureParser::ParseDescriptorTable() {
+  DescriptorTable Table;
+
+  if (ConsumeExpectedToken(TokenKind::pu_l_paren))
+    return true;
+
+  // `DescriptorTable()` with no entries at all.
+  if (TryConsumeExpectedToken(TokenKind::pu_r_paren)) {
+    Elements.push_back(Table);
+    return false;
+  }
+
+  bool HasVisibility = false;
+  do {
+    if (!TryConsumeExpectedToken(TokenKind::kw_visibility)) {
+      // Not the visibility parameter, so this entry must be a clause.
+      if (ParseDescriptorTableClause())
+        return true;
+      ++Table.NumClauses;
+      continue;
+    }
+
+    // The visibility parameter may only be given once.
+    if (HasVisibility) {
+      Diags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_repeat_param)
+          << FormatTokenKinds(CurToken.Kind);
+      return true;
+    }
+    HasVisibility = true;
+    if (ParseParam(&Table.Visibility))
+      return true;
+  } while (TryConsumeExpectedToken(TokenKind::pu_comma));
+
+  if (ConsumeExpectedToken(TokenKind::pu_r_paren))
+    return true;
+
+  // The table is appended after the clauses parsed above, which
+  // ParseDescriptorTableClause has already pushed onto Elements.
+  Elements.push_back(Table);
+  return false;
+}
+
+// Parse one descriptor table clause:
+//   (CBV | SRV | UAV | Sampler) ( register [, optional-param = value]* )
+// The register comes first and is mandatory. The optional parameters are
+// numDescriptors, space, offset and flags. Appends the clause to Elements;
+// returns true on error.
+bool RootSignatureParser::ParseDescriptorTableClause() {
+  if (ConsumeExpectedToken({TokenKind::kw_CBV, TokenKind::kw_SRV,
+                            TokenKind::kw_UAV, TokenKind::kw_Sampler}))
+    return true;
+
+  // Map the clause keyword onto its clause type.
+  DescriptorTableClause Clause;
+  switch (CurToken.Kind) {
+  case TokenKind::kw_CBV:
+    Clause.Type = ClauseType::CBuffer;
+    break;
+  case TokenKind::kw_SRV:
+    Clause.Type = ClauseType::SRV;
+    break;
+  case TokenKind::kw_UAV:
+    Clause.Type = ClauseType::UAV;
+    break;
+  case TokenKind::kw_Sampler:
+    Clause.Type = ClauseType::Sampler;
+    break;
+  default:
+    llvm_unreachable("Switch for an expected token was not provided");
+  }
+  // Called after Type is assigned.
+  // NOTE(review): presumably the default flags depend on the clause type;
+  // confirm in SetDefaultFlags.
+  Clause.SetDefaultFlags();
+
+  if (ConsumeExpectedToken(TokenKind::pu_l_paren))
+    return true;
+
+  // Consume the mandatory register parameter.
+  if (ParseRegister(&Clause.Register))
+    return true;
+
+  // Optional parameters: each keyword maps to the Clause field its value is
+  // written into. They may appear in any order, each at most once.
+  llvm::SmallDenseMap<TokenKind, ParamType> RefMap = {
+      {TokenKind::kw_numDescriptors, &Clause.NumDescriptors},
+      {TokenKind::kw_space, &Clause.Space},
+      {TokenKind::kw_offset, &Clause.Offset},
+      {TokenKind::kw_flags, &Clause.Flags},
+  };
+  // Pass the map itself; a braced `{RefMap}` only bound the reference by a
+  // special initialization rule and read like a temporary copy.
+  if (ParseOptionalParams(RefMap))
+    return true;
+
+  if (ConsumeExpectedToken(TokenKind::pu_r_paren))
+    return true;
+
+  Elements.push_back(Clause);
+  return false;
+}
+
+// Overload set assembled from a list of callables, for use with std::visit.
+// Inheriting from every callable and pulling in each operator() lets
+// overload resolution pick the one matching the active variant alternative.
+template <typename... Ts> struct ParseMethods : Ts... {
+  using Ts::operator()...;
+};
+// Deduction guide so `ParseMethods{lambda1, lambda2, ...}` deduces Ts.
+template <typename... Ts> ParseMethods(Ts...) -> ParseMethods<Ts...>;
+
+// Parse `= value` for a parameter and store the value through Ref. Which
+// value parser runs depends on the pointer type held in the ParamType
+// variant. Returns true on error.
+bool RootSignatureParser::ParseParam(ParamType Ref) {
+  if (ConsumeExpectedToken(TokenKind::pu_equal))
+    return true;
+
+  // Each visitor alternative returns its parser's error flag, and std::visit
+  // returns it directly. This replaces an uninitialized `bool Error` that was
+  // assigned from inside the lambdas.
+  return std::visit(
+      ParseMethods{
+          [&](uint32_t *X) { return ParseUInt(X); },
+          [&](DescriptorRangeOffset *X) {
+            return ParseDescriptorRangeOffset(X);
+          },
+          [&](ShaderVisibility *Enum) { return ParseShaderVisibility(Enum); },
+          [&](DescriptorRangeFlags *Flags) {
+            return ParseDescriptorRangeFlags(Flags);
+          },
+      },
+      Ref);
+}
+
+// Parse zero or more `, keyword = value` parameters. The accepted keywords
+// are the keys of RefMap; each value is stored through the mapped pointer.
+// Each keyword may appear at most once. Returns true on error.
+bool RootSignatureParser::ParseOptionalParams(
+    llvm::SmallDenseMap<TokenKind, ParamType> &RefMap) {
+  SmallVector<TokenKind> Keywords;
+  for (const auto &Entry : RefMap)
+    Keywords.push_back(Entry.first);
+
+  // Keywords already specified, used to diagnose repeats.
+  llvm::SmallDenseSet<TokenKind> Specified;
+
+  while (TryConsumeExpectedToken(TokenKind::pu_comma)) {
+    if (ConsumeExpectedToken(Keywords))
+      return true;
+
+    const TokenKind Kind = CurToken.Kind;
+    // insert() reports false when the keyword was already present.
+    if (!Specified.insert(Kind).second) {
+      Diags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_repeat_param)
+          << FormatTokenKinds({Kind});
+      return true;
+    }
+
+    // Kind is one of RefMap's keys, since ConsumeExpectedToken accepted it.
+    if (ParseParam(RefMap[Kind]))
+      return true;
+  }
+
+  return false;
+}
+
+// Convert the numeric spelling of CurToken into a 32-bit unsigned value
+// stored in X. Returns true if an error was reported, either by
+// NumericLiteralParser or as an overflow diagnostic.
+bool RootSignatureParser::HandleUIntLiteral(uint32_t &X) {
+  // Parse the numeric value and do semantic checks on its specification.
+  clang::NumericLiteralParser Literal(CurToken.NumSpelling, CurToken.TokLoc,
+                                      PP.getSourceManager(), PP.getLangOpts(),
+                                      PP.getTargetInfo(), PP.getDiagnostics());
+  if (Literal.hadError)
+    return true; // Error has already been reported so just return
+
+  assert(Literal.isIntegerLiteral() && "IsNumberChar will only support digits");
+
+  // Evaluate into a 32-bit *unsigned* APSInt (the second constructor argument
+  // is isUnsigned). The previous `APSInt(32, false)` built a signed value.
+  // GetIntegerValue returns true when the literal does not fit.
+  llvm::APSInt Val(32, /*isUnsigned=*/true);
+  if (Literal.GetIntegerValue(Val)) {
+    // Report that the value has overflowed
+    PP.getDiagnostics().Report(CurToken.TokLoc,
+                               diag::err_hlsl_number_literal_overflow)
+        << 0 << CurToken.NumSpelling;
+    return true;
+  }
+
+  // Zero-extend rather than sign-extend. The value fits in 32 bits, so the
+  // narrowing is exact.
+  X = static_cast<uint32_t>(Val.getZExtValue());
+  return false;
+}
+
+// Parse a register token (b/t/u/s followed by a number) into *Reg: the token
+// kind selects the register type, and the numeric part becomes the register
+// number. Returns true on error.
+bool RootSignatureParser::ParseRegister(Register *Reg) {
+  if (ConsumeExpectedToken(
+          {TokenKind::bReg, TokenKind::tReg, TokenKind::uReg, TokenKind::sReg}))
+    return true;
+
+  switch (CurToken.Kind) {
+  case TokenKind::bReg:
+    Reg->ViewType = RegisterType::BReg;
+    break;
+  case TokenKind::tReg:
+    Reg->ViewType = RegisterType::TReg;
+    break;
+  case TokenKind::uReg:
+    Reg->ViewType = RegisterType::UReg;
+    break;
+  case TokenKind::sReg:
+    Reg->ViewType = RegisterType::SReg;
+    break;
+  default:
+    llvm_unreachable("Switch for an expected token was not provided");
+  }
+
+  // Any NumericLiteralParser error has already been reported; just forward
+  // the result.
+  return HandleUIntLiteral(Reg->Number);
+}
+
+bool RootSignatureParser::ParseUInt(uint32_t *X) {
+ // Treat a postively signed integer as though it is unsigned to match DXC
+ TryConsumeExpectedToken(TokenKind::pu_plus);
+ if (ConsumeExpectedToken(TokenKind::int_literal))
+ return true;
+
+ if (HandleUIntLiteral(*X))
+ return true; // propogate NumericLiteralParser error
+
+ return false;
+}
+
+// Parse a descriptor range offset into *X. It is either an integer literal
+// or the append enumerator. Returns true on error.
+bool RootSignatureParser::ParseDescriptorRangeOffset(DescriptorRangeOffset *X) {
+  if (ConsumeExpectedToken(
+          {TokenKind::int_literal, TokenKind::en_DescriptorRangeOffsetAppend}))
+    return true;
+
+  if (CurToken.Kind == TokenKind::int_literal) {
+    uint32_t Value;
+    if (HandleUIntLiteral(Value))
+      return true; // NumericLiteralParser has already reported the error
+    *X = DescriptorRangeOffset(Value);
+    return false;
+  }
+
+  // The only other accepted token is the append enumerator, which maps onto
+  // the static DescriptorTableOffsetAppend value.
+  *X = DescriptorTableOffsetAppend;
+  return false;
+}
+
+// Parse one enum value into *Enum, accepting any keyword token that is a key
+// of EnumMap. When AllowZero is set, the integer literal 0 is also accepted
+// and yields EnumType(0); any other integer is diagnosed. Returns true on
+// error.
+template <bool AllowZero, typename EnumType>
+bool RootSignatureParser::ParseEnum(
+    llvm::SmallDenseMap<TokenKind, EnumType> &EnumMap, EnumType *Enum) {
+  SmallVector<TokenKind> Accepted;
+  if constexpr (AllowZero)
+    Accepted.push_back(TokenKind::int_literal); // '0' is a valid flag value
+  for (const auto &Entry : EnumMap)
+    Accepted.push_back(Entry.first);
+
+  if (ConsumeExpectedToken(Accepted))
+    return true;
+
+  // Integer literal: only 0 is meaningful, and it selects EnumType(0).
+  if (CurToken.Kind == TokenKind::int_literal) {
+    uint32_t Value;
+    if (HandleUIntLiteral(Value))
+      return true; // NumericLiteralParser has already reported the error
+    if (Value != 0) {
+      Diags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_non_zero_flag);
+      return true;
+    }
+    *Enum = EnumType(0);
+    return false;
+  }
+
+  // Keyword: look up its enum value. It must be present, because
+  // ConsumeExpectedToken only accepted keys of EnumMap.
+  auto It = EnumMap.find(CurToken.Kind);
+  if (It != EnumMap.end()) {
+    *Enum = It->second;
+    return false;
+  }
+
+  llvm_unreachable("Switch for an expected token was not provided");
+}
+
+bool RootSignatureParser::ParseShaderVisibility(ShaderVisibility *Enum) {
+ // Define the possible flag kinds
+ llvm::SmallDenseMap<TokenKind, ShaderVisibility> EnumMap = {
+#define SHADER_VISIBILITY_ENUM(NAME, LIT) \
+ {TokenKind::en_##NAME, ShaderVisibility::NAME},
+#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
+ };
+
+ return ParseEnum(EnumMap, Enum);
+}
+
+// Parse a '|'-separated list of flags (each a FlagMap keyword or a literal 0)
+// and OR them together into *Flags. Returns true on error.
+template <typename FlagType>
+bool RootSignatureParser::ParseFlags(
+    llvm::SmallDenseMap<TokenKind, FlagType> &FlagMap, FlagType *Flags) {
+  // Start from 0 rather than any prior default so the flags can be OR'd in.
+  *Flags = FlagType(0);
+
+  for (;;) {
+    FlagType Next;
+    if (ParseEnum<true>(FlagMap, &Next))
+      return true;
+    *Flags |= Next;
+    // Stop as soon as no '|' follows.
+    if (!TryConsumeExpectedToken(TokenKind::pu_or))
+      return false;
+  }
+}
+
+// Parse the value of a descriptor range `flags` parameter into *Flags.
+// Returns true on error.
+bool RootSignatureParser::ParseDescriptorRangeFlags(
+    DescriptorRangeFlags *Flags) {
+  // One entry per DESCRIPTOR_RANGE_FLAG_ENUM in the token definitions,
+  // mapping the en_<NAME> token onto DescriptorRangeFlags::<NAME>.
+  llvm::SmallDenseMap<TokenKind, DescriptorRangeFlags> RangeFlagMap = {
+#define DESCRIPTOR_RANGE_FLAG_ENUM(NAME, LIT, ON) \
+  {TokenKind::en_##NAME, DescriptorRangeFlags::NAME},
+#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
+  };
+
+  // A '|'-separated combination of these flags, or a literal 0.
+  return ParseFlags(RangeFlagMap, Flags);
+}
+
+// Returns true if Kind equals any of the token kinds in AnyExpected.
+static bool IsExpectedToken(TokenKind Kind, ArrayRef<TokenKind> AnyExpected) {
+  for (TokenKind Candidate : AnyExpected) {
+    if (Candidate == Kind)
+      return true;
+  }
+  return false;
+}
+
+// Single-kind convenience overload of the ArrayRef version below.
+bool RootSignatureParser::PeekExpectedToken(TokenKind Expected) {
+  return PeekExpectedToken(ArrayRef<TokenKind>(Expected));
+}
+
+// Report whether the lexer's next token (from PeekNextToken) is one of
+// AnyExpected. CurToken is not modified.
+bool RootSignatureParser::PeekExpectedToken(ArrayRef<TokenKind> AnyExpected) {
+  return IsExpectedToken(Lexer.PeekNextToken().Kind, AnyExpected);
+}
+
+// Single-kind convenience overload; forwards to the ArrayRef overload.
+bool RootSignatureParser::ConsumeExpectedToken(TokenKind Expected) {
+  return ConsumeExpectedToken(ArrayRef<TokenKind>(Expected));
+}
+
+bool RootSignatureParser::ConsumeExpectedToken(
+ ArrayRef<TokenKind> AnyExpected) {
+ ConsumeNextToken();
+ if (IsExpectedToken(CurToken.Kind, AnyExpected))
+ return false;
+
+ // Report unexpected token kind error
+ Diags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_unexpected_token_kind)
----------------
inbelic wrote:
I think this effectively uses a similar pattern to Clang's parser (albeit with different names, which is confusing): this function is equivalent to the `ExpectAndConsume` function in Clang's implementation.
`ExpectAndConsume` does, however, allow passing down custom, context-specific diagnostic messages. I have added a quick prototype commit that allows `ConsumeExpectedToken` to use custom diagnostic messages as well.
Is this what you had intended? Do you think we should also change the function names to align with the Clang parser?
https://github.com/llvm/llvm-project/pull/122982
More information about the cfe-commits
mailing list