[clang] [llvm] [HLSL][RootSignature] Add infrastructure to parse parameters (PR #133800)

Finn Plummer via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 17 12:07:07 PDT 2025


https://github.com/inbelic updated https://github.com/llvm/llvm-project/pull/133800

>From 9ff87eb37437dc92a554d1d89b236e9a13249694 Mon Sep 17 00:00:00 2001
From: Finn Plummer <finnplummer at microsoft.com>
Date: Mon, 31 Mar 2025 18:29:26 +0000
Subject: [PATCH 01/11] [HLSL][RootSignature] Add infrastructure to parse
 parameters

- defines `ParamType` as a way to represent a reference to some
parameter in a root signature

- defines `parseParam` and `parseParams` as infrastructure to declare
how the parameters of a given struct should be parsed in an
order-independent manner
---
 .../clang/Basic/DiagnosticParseKinds.td       |  4 +-
 .../clang/Parse/ParseHLSLRootSignature.h      | 32 ++++++++++++
 clang/lib/Parse/ParseHLSLRootSignature.cpp    | 51 +++++++++++++++++++
 .../llvm/Frontend/HLSL/HLSLRootSignature.h    |  6 +++
 4 files changed, 92 insertions(+), 1 deletion(-)

diff --git a/clang/include/clang/Basic/DiagnosticParseKinds.td b/clang/include/clang/Basic/DiagnosticParseKinds.td
index 954f538e15026..2a0bf81c68e44 100644
--- a/clang/include/clang/Basic/DiagnosticParseKinds.td
+++ b/clang/include/clang/Basic/DiagnosticParseKinds.td
@@ -1830,8 +1830,10 @@ def err_hlsl_virtual_function
 def err_hlsl_virtual_inheritance
     : Error<"virtual inheritance is unsupported in HLSL">;
 
-// HLSL Root Siganture diagnostic messages
+// HLSL Root Signature Parser Diagnostics
 def err_hlsl_unexpected_end_of_params
     : Error<"expected %0 to denote end of parameters, or, another valid parameter of %1">;
+def err_hlsl_rootsig_repeat_param : Error<"specified the same parameter '%0' multiple times">;
+def err_hlsl_rootsig_missing_param : Error<"did not specify mandatory parameter '%0'">;
 
 } // end of Parser diagnostics
diff --git a/clang/include/clang/Parse/ParseHLSLRootSignature.h b/clang/include/clang/Parse/ParseHLSLRootSignature.h
index 18cc2c6692551..55d84f91b8834 100644
--- a/clang/include/clang/Parse/ParseHLSLRootSignature.h
+++ b/clang/include/clang/Parse/ParseHLSLRootSignature.h
@@ -69,6 +69,38 @@ class RootSignatureParser {
   bool parseDescriptorTable();
   bool parseDescriptorTableClause();
 
+  /// Each unique ParamType will have a custom parse method defined that we can
+  /// use to invoke the parameters.
+  ///
+  /// This function will switch on the ParamType using std::visit and dispatch
+  /// onto the corresponding parse method
+  bool parseParam(llvm::hlsl::rootsig::ParamType Ref);
+
+  /// Parameter arguments (eg. `bReg`, `space`, ...) can be specified in any
+  /// order, exactly once, and only a subset are mandatory. This function acts
+  /// as the infastructure to do so in a declarative way.
+  ///
+  /// For the example:
+  ///  SmallDenseMap<TokenKind, ParamType> Params = {
+  ///    TokenKind::bReg, &Clause.Register,
+  ///    TokenKind::kw_space, &Clause.Space
+  ///  };
+  ///  SmallDenseSet<TokenKind> Mandatory = {
+  ///    TokenKind::kw_numDescriptors
+  ///  };
+  ///
+  /// We can read it is as:
+  ///
+  /// when 'b0' is encountered, invoke the parse method for the type
+  ///   of &Clause.Register (Register *) and update the parameter
+  /// when 'space' is encountered, invoke a parse method for the type
+  ///   of &Clause.Space (uint32_t *) and update the parameter
+  ///
+  /// and 'bReg' must be specified
+  bool parseParams(
+      llvm::SmallDenseMap<TokenKind, llvm::hlsl::rootsig::ParamType> &Params,
+      llvm::SmallDenseSet<TokenKind> &Mandatory);
+
   /// Invoke the Lexer to consume a token and update CurToken with the result
   void consumeNextToken() { CurToken = Lexer.ConsumeToken(); }
 
diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp
index 93a9689ebdf72..54c51e56b84dd 100644
--- a/clang/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -120,6 +120,57 @@ bool RootSignatureParser::parseDescriptorTableClause() {
   return false;
 }
 
+// Helper struct so that we can use the overloaded notation of std::visit
+template <class... Ts> struct ParseMethods : Ts... { using Ts::operator()...; };
+template <class... Ts> ParseMethods(Ts...) -> ParseMethods<Ts...>;
+
+bool RootSignatureParser::parseParam(ParamType Ref) {
+  bool Error = true;
+  std::visit(ParseMethods{}, Ref);
+
+  return Error;
+}
+
+bool RootSignatureParser::parseParams(
+    llvm::SmallDenseMap<TokenKind, ParamType> &Params,
+    llvm::SmallDenseSet<TokenKind> &Mandatory) {
+
+  // Initialize a vector of possible keywords
+  SmallVector<TokenKind> Keywords;
+  for (auto Pair : Params)
+    Keywords.push_back(Pair.first);
+
+  // Keep track of which keywords have been seen to report duplicates
+  llvm::SmallDenseSet<TokenKind> Seen;
+
+  while (tryConsumeExpectedToken(Keywords)) {
+    if (Seen.contains(CurToken.Kind)) {
+      getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_repeat_param)
+          << CurToken.Kind;
+      return true;
+    }
+    Seen.insert(CurToken.Kind);
+
+    if (parseParam(Params[CurToken.Kind]))
+      return true;
+
+    if (!tryConsumeExpectedToken(TokenKind::pu_comma))
+      break;
+  }
+
+  bool AllMandatoryDefined = true;
+  for (auto Kind : Mandatory) {
+    bool SeenParam = Seen.contains(Kind);
+    if (!SeenParam) {
+      getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_missing_param)
+          << Kind;
+    }
+    AllMandatoryDefined &= SeenParam;
+  }
+
+  return !AllMandatoryDefined;
+}
+
 bool RootSignatureParser::peekExpectedToken(TokenKind Expected) {
   return peekExpectedToken(ArrayRef{Expected});
 }
diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
index c1b67844c747f..825beeea961cd 100644
--- a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
+++ b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
@@ -37,6 +37,12 @@ struct DescriptorTableClause {
 // Models RootElement : DescriptorTable | DescriptorTableClause
 using RootElement = std::variant<DescriptorTable, DescriptorTableClause>;
 
+// ParamType is used as an 'any' type that will reference to a parameter in
+// RootElement. Each variant of ParamType is expected to have a Parse method
+// defined that will be dispatched on when we are attempting to parse a
+// parameter
+using ParamType = std::variant<std::monostate>;
+
 } // namespace rootsig
 } // namespace hlsl
 } // namespace llvm

>From 180c33c33e9ee6e2f38d7b092d564d295591bc81 Mon Sep 17 00:00:00 2001
From: Finn Plummer <finnplummer at microsoft.com>
Date: Thu, 27 Mar 2025 18:34:02 +0000
Subject: [PATCH 02/11] define ParseUInt and ParseRegister to plug into
 parameters

---
 .../clang/Basic/DiagnosticParseKinds.td       |   1 +
 .../clang/Parse/ParseHLSLRootSignature.h      |   8 +
 clang/lib/Parse/ParseHLSLRootSignature.cpp    | 105 +++++++++++--
 .../Parse/ParseHLSLRootSignatureTest.cpp      | 146 +++++++++++++++++-
 .../llvm/Frontend/HLSL/HLSLRootSignature.h    |  11 +-
 5 files changed, 257 insertions(+), 14 deletions(-)

diff --git a/clang/include/clang/Basic/DiagnosticParseKinds.td b/clang/include/clang/Basic/DiagnosticParseKinds.td
index 2a0bf81c68e44..e801dfb3b85f4 100644
--- a/clang/include/clang/Basic/DiagnosticParseKinds.td
+++ b/clang/include/clang/Basic/DiagnosticParseKinds.td
@@ -1835,5 +1835,6 @@ def err_hlsl_unexpected_end_of_params
     : Error<"expected %0 to denote end of parameters, or, another valid parameter of %1">;
 def err_hlsl_rootsig_repeat_param : Error<"specified the same parameter '%0' multiple times">;
 def err_hlsl_rootsig_missing_param : Error<"did not specify mandatory parameter '%0'">;
+def err_hlsl_number_literal_overflow : Error<"integer literal is too large to be represented as a 32-bit %select{signed |}0 integer type">;
 
 } // end of Parser diagnostics
diff --git a/clang/include/clang/Parse/ParseHLSLRootSignature.h b/clang/include/clang/Parse/ParseHLSLRootSignature.h
index 55d84f91b8834..f015abe3549aa 100644
--- a/clang/include/clang/Parse/ParseHLSLRootSignature.h
+++ b/clang/include/clang/Parse/ParseHLSLRootSignature.h
@@ -101,6 +101,14 @@ class RootSignatureParser {
       llvm::SmallDenseMap<TokenKind, llvm::hlsl::rootsig::ParamType> &Params,
       llvm::SmallDenseSet<TokenKind> &Mandatory);
 
+  /// Parameter parse methods corresponding to a ParamType
+  bool parseUIntParam(uint32_t *X);
+  bool parseRegister(llvm::hlsl::rootsig::Register *Reg);
+
+  /// Use NumericLiteralParser to convert CurToken.NumSpelling into a unsigned
+  /// 32-bit integer
+  bool handleUIntLiteral(uint32_t *X);
+
   /// Invoke the Lexer to consume a token and update CurToken with the result
   void consumeNextToken() { CurToken = Lexer.ConsumeToken(); }
 
diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp
index 54c51e56b84dd..b599b9714a949 100644
--- a/clang/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -8,6 +8,8 @@
 
 #include "clang/Parse/ParseHLSLRootSignature.h"
 
+#include "clang/Lex/LiteralSupport.h"
+
 #include "llvm/Support/raw_ostream.h"
 
 using namespace llvm::hlsl::rootsig;
@@ -89,46 +91,70 @@ bool RootSignatureParser::parseDescriptorTableClause() {
           CurToken.TokKind == TokenKind::kw_UAV ||
           CurToken.TokKind == TokenKind::kw_Sampler) &&
          "Expects to only be invoked starting at given keyword");
+  TokenKind ParamKind = CurToken.TokKind; // retain for diagnostics
 
   DescriptorTableClause Clause;
-  switch (CurToken.TokKind) {
+  TokenKind ExpectedRegister;
+  switch (ParamKind) {
   default:
     llvm_unreachable("Switch for consumed token was not provided");
   case TokenKind::kw_CBV:
     Clause.Type = ClauseType::CBuffer;
+    ExpectedRegister = TokenKind::bReg;
     break;
   case TokenKind::kw_SRV:
     Clause.Type = ClauseType::SRV;
+    ExpectedRegister = TokenKind::tReg;
     break;
   case TokenKind::kw_UAV:
     Clause.Type = ClauseType::UAV;
+    ExpectedRegister = TokenKind::uReg;
     break;
   case TokenKind::kw_Sampler:
     Clause.Type = ClauseType::Sampler;
+    ExpectedRegister = TokenKind::sReg;
     break;
   }
 
   if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
-                           CurToken.TokKind))
+                           ParamKind))
     return true;
 
-  if (consumeExpectedToken(TokenKind::pu_r_paren, diag::err_expected_after,
-                           CurToken.TokKind))
+  llvm::SmallDenseMap<TokenKind, ParamType> Params = {
+      {ExpectedRegister, &Clause.Register},
+      {TokenKind::kw_space, &Clause.Space},
+  };
+  llvm::SmallDenseSet<TokenKind> Mandatory = {
+      ExpectedRegister,
+  };
+
+  if (parseParams(Params, Mandatory))
+    return true;
+
+  if (consumeExpectedToken(TokenKind::pu_r_paren, diag::err_hlsl_unexpected_end_of_params,
+                           ParamKind))
     return true;
 
   Elements.push_back(Clause);
   return false;
 }
 
-// Helper struct so that we can use the overloaded notation of std::visit
+// Helper struct defined to use the overloaded notation of std::visit.
 template <class... Ts> struct ParseMethods : Ts... { using Ts::operator()...; };
 template <class... Ts> ParseMethods(Ts...) -> ParseMethods<Ts...>;
 
 bool RootSignatureParser::parseParam(ParamType Ref) {
-  bool Error = true;
-  std::visit(ParseMethods{}, Ref);
-
-  return Error;
+  return std::visit(
+      ParseMethods{
+          [this](Register *X) -> bool { return parseRegister(X); },
+          [this](uint32_t *X) -> bool {
+            return consumeExpectedToken(TokenKind::pu_equal,
+                                        diag::err_expected_after,
+                                        CurToken.Kind) ||
+                   parseUIntParam(X);
+          },
+      },
+      Ref);
 }
 
 bool RootSignatureParser::parseParams(
@@ -171,6 +197,67 @@ bool RootSignatureParser::parseParams(
   return !AllMandatoryDefined;
 }
 
+bool RootSignatureParser::parseUIntParam(uint32_t *X) {
+  assert(CurToken.Kind == TokenKind::pu_equal &&
+         "Expects to only be invoked starting at given keyword");
+  tryConsumeExpectedToken(TokenKind::pu_plus);
+  return consumeExpectedToken(TokenKind::int_literal, diag::err_expected_after,
+                              CurToken.Kind) ||
+         handleUIntLiteral(X);
+}
+
+bool RootSignatureParser::parseRegister(Register *Register) {
+  assert(
+      (CurToken.Kind == TokenKind::bReg || CurToken.Kind == TokenKind::tReg ||
+       CurToken.Kind == TokenKind::uReg || CurToken.Kind == TokenKind::sReg) &&
+      "Expects to only be invoked starting at given keyword");
+
+  switch (CurToken.Kind) {
+  case TokenKind::bReg:
+    Register->ViewType = RegisterType::BReg;
+    break;
+  case TokenKind::tReg:
+    Register->ViewType = RegisterType::TReg;
+    break;
+  case TokenKind::uReg:
+    Register->ViewType = RegisterType::UReg;
+    break;
+  case TokenKind::sReg:
+    Register->ViewType = RegisterType::SReg;
+    break;
+  default:
+    break; // Unreachable given Try + assert pattern
+  }
+
+  if (handleUIntLiteral(&Register->Number))
+    return true; // propogate NumericLiteralParser error
+
+  return false;
+}
+
+bool RootSignatureParser::handleUIntLiteral(uint32_t *X) {
+  // Parse the numeric value and do semantic checks on its specification
+  clang::NumericLiteralParser Literal(CurToken.NumSpelling, CurToken.TokLoc,
+                                      PP.getSourceManager(), PP.getLangOpts(),
+                                      PP.getTargetInfo(), PP.getDiagnostics());
+  if (Literal.hadError)
+    return true; // Error has already been reported so just return
+
+  assert(Literal.isIntegerLiteral() && "IsNumberChar will only support digits");
+
+  llvm::APSInt Val = llvm::APSInt(32, false);
+  if (Literal.GetIntegerValue(Val)) {
+    // Report that the value has overflowed
+    PP.getDiagnostics().Report(CurToken.TokLoc,
+                               diag::err_hlsl_number_literal_overflow)
+        << 0 << CurToken.NumSpelling;
+    return true;
+  }
+
+  *X = Val.getExtValue();
+  return false;
+}
+
 bool RootSignatureParser::peekExpectedToken(TokenKind Expected) {
   return peekExpectedToken(ArrayRef{Expected});
 }
diff --git a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
index acdf455a5d6aa..5b162e9a1b4cd 100644
--- a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
+++ b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
@@ -130,10 +130,10 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseEmptyTest) {
 TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
   const llvm::StringLiteral Source = R"cc(
     DescriptorTable(
-      CBV(),
-      SRV(),
-      Sampler(),
-      UAV()
+      CBV(b0),
+      SRV(space = 3, t42),
+      Sampler(s987, space = +2),
+      UAV(u4294967294)
     ),
     DescriptorTable()
   )cc";
@@ -155,18 +155,34 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
   RootElement Elem = Elements[0];
   ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::CBuffer);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.ViewType,
+            RegisterType::BReg);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.Number, 0u);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 0u);
 
   Elem = Elements[1];
   ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::SRV);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.ViewType,
+            RegisterType::TReg);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.Number, 42u);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 3u);
 
   Elem = Elements[2];
   ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::Sampler);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.ViewType,
+            RegisterType::SReg);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.Number, 987u);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 2u);
 
   Elem = Elements[3];
   ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::UAV);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.ViewType,
+            RegisterType::UReg);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.Number, 4294967294u);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 0u);
 
   Elem = Elements[4];
   ASSERT_TRUE(std::holds_alternative<DescriptorTable>(Elem));
@@ -176,6 +192,32 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
   Elem = Elements[5];
   ASSERT_TRUE(std::holds_alternative<DescriptorTable>(Elem));
   ASSERT_EQ(std::get<DescriptorTable>(Elem).NumClauses, 0u);
+
+  ASSERT_TRUE(Consumer->isSatisfied());
+}
+
+TEST_F(ParseHLSLRootSignatureTest, ValidTrailingCommaTest) {
+  // This test will checks we can handling trailing commas ','
+  const llvm::StringLiteral Source = R"cc(
+    DescriptorTable(
+      CBV(b0, ),
+      SRV(t42),
+    )
+  )cc";
+
+  TrivialModuleLoader ModLoader;
+  auto PP = createPP(Source, ModLoader);
+  auto TokLoc = SourceLocation();
+
+  hlsl::RootSignatureLexer Lexer(Source, TokLoc);
+  SmallVector<RootElement> Elements;
+  hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);
+
+  // Test no diagnostics produced
+  Consumer->setNoDiag();
+
+  ASSERT_FALSE(Parser.parse());
+
   ASSERT_TRUE(Consumer->isSatisfied());
 }
 
@@ -237,6 +279,102 @@ TEST_F(ParseHLSLRootSignatureTest, InvalidParseUnexpectedEndOfStreamTest) {
 
   // Test correct diagnostic produced - end of stream
   Consumer->setExpected(diag::err_expected_after);
+
+  ASSERT_TRUE(Parser.parse());
+
+  ASSERT_TRUE(Consumer->isSatisfied());
+}
+
+TEST_F(ParseHLSLRootSignatureTest, InvalidMissingParameterTest) {
+  // This test will check that the parsing fails due a mandatory
+  // parameter (register) not being specified
+  const llvm::StringLiteral Source = R"cc(
+    DescriptorTable(
+      CBV()
+    )
+  )cc";
+
+  TrivialModuleLoader ModLoader;
+  auto PP = createPP(Source, ModLoader);
+  auto TokLoc = SourceLocation();
+
+  hlsl::RootSignatureLexer Lexer(Source, TokLoc);
+  SmallVector<RootElement> Elements;
+  hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);
+
+  // Test correct diagnostic produced
+  Consumer->setExpected(diag::err_hlsl_rootsig_missing_param);
+  ASSERT_TRUE(Parser.parse());
+
+  ASSERT_TRUE(Consumer->isSatisfied());
+}
+
+TEST_F(ParseHLSLRootSignatureTest, InvalidRepeatedMandatoryParameterTest) {
+  // This test will check that the parsing fails due the same mandatory
+  // parameter being specified multiple times
+  const llvm::StringLiteral Source = R"cc(
+    DescriptorTable(
+      CBV(b32, b84)
+    )
+  )cc";
+
+  TrivialModuleLoader ModLoader;
+  auto PP = createPP(Source, ModLoader);
+  auto TokLoc = SourceLocation();
+
+  hlsl::RootSignatureLexer Lexer(Source, TokLoc);
+  SmallVector<RootElement> Elements;
+  hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);
+
+  // Test correct diagnostic produced
+  Consumer->setExpected(diag::err_hlsl_rootsig_repeat_param);
+  ASSERT_TRUE(Parser.parse());
+
+  ASSERT_TRUE(Consumer->isSatisfied());
+}
+
+TEST_F(ParseHLSLRootSignatureTest, InvalidRepeatedOptionalParameterTest) {
+  // This test will check that the parsing fails due the same optional
+  // parameter being specified multiple times
+  const llvm::StringLiteral Source = R"cc(
+    DescriptorTable(
+      CBV(space = 2, space = 0)
+    )
+  )cc";
+
+  TrivialModuleLoader ModLoader;
+  auto PP = createPP(Source, ModLoader);
+  auto TokLoc = SourceLocation();
+
+  hlsl::RootSignatureLexer Lexer(Source, TokLoc);
+  SmallVector<RootElement> Elements;
+  hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);
+
+  // Test correct diagnostic produced
+  Consumer->setExpected(diag::err_hlsl_rootsig_repeat_param);
+  ASSERT_TRUE(Parser.parse());
+
+  ASSERT_TRUE(Consumer->isSatisfied());
+}
+
+TEST_F(ParseHLSLRootSignatureTest, InvalidLexOverflowedNumberTest) {
+  // This test will check that the lexing fails due to an integer overflow
+  const llvm::StringLiteral Source = R"cc(
+    DescriptorTable(
+      CBV(b4294967296)
+    )
+  )cc";
+
+  TrivialModuleLoader ModLoader;
+  auto PP = createPP(Source, ModLoader);
+  auto TokLoc = SourceLocation();
+
+  hlsl::RootSignatureLexer Lexer(Source, TokLoc);
+  SmallVector<RootElement> Elements;
+  hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);
+
+  // Test correct diagnostic produced
+  Consumer->setExpected(diag::err_hlsl_number_literal_overflow);
   ASSERT_TRUE(Parser.parse());
 
   ASSERT_TRUE(Consumer->isSatisfied());
diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
index 825beeea961cd..0ae8879b6c7d5 100644
--- a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
+++ b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
@@ -23,6 +23,13 @@ namespace rootsig {
 
 // Definitions of the in-memory data layout structures
 
+// Models the different registers: bReg | tReg | uReg | sReg
+enum class RegisterType { BReg, TReg, UReg, SReg };
+struct Register {
+  RegisterType ViewType;
+  uint32_t Number;
+};
+
 // Models the end of a descriptor table and stores its visibility
 struct DescriptorTable {
   uint32_t NumClauses = 0; // The number of clauses in the table
@@ -32,6 +39,8 @@ struct DescriptorTable {
 using ClauseType = llvm::dxil::ResourceClass;
 struct DescriptorTableClause {
   ClauseType Type;
+  Register Register;
+  uint32_t Space = 0;
 };
 
 // Models RootElement : DescriptorTable | DescriptorTableClause
@@ -41,7 +50,7 @@ using RootElement = std::variant<DescriptorTable, DescriptorTableClause>;
 // RootElement. Each variant of ParamType is expected to have a Parse method
 // defined that will be dispatched on when we are attempting to parse a
 // parameter
-using ParamType = std::variant<std::monostate>;
+using ParamType = std::variant<uint32_t *, Register *>;
 
 } // namespace rootsig
 } // namespace hlsl

>From ebab0ca6d5f344152fd5db9c665f82ceeb303b20 Mon Sep 17 00:00:00 2001
From: Finn Plummer <finnplummer at microsoft.com>
Date: Mon, 31 Mar 2025 18:47:15 +0000
Subject: [PATCH 03/11] self-review: use consumeExpectedToken api to report
 unexpected_end_of_params diag

---
 clang/lib/Parse/ParseHLSLRootSignature.cpp | 20 +++++++++-----------
 1 file changed, 9 insertions(+), 11 deletions(-)

diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp
index b599b9714a949..65ff90092a4d7 100644
--- a/clang/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -43,12 +43,11 @@ bool RootSignatureParser::parse() {
       break;
   }
 
-  if (!tryConsumeExpectedToken(TokenKind::end_of_stream)) {
-    getDiags().Report(CurToken.TokLoc, diag::err_hlsl_unexpected_end_of_params)
-        << /*expected=*/TokenKind::end_of_stream
-        << /*param of=*/TokenKind::kw_RootSignature;
+  if (consumeExpectedToken(TokenKind::end_of_stream,
+                           diag::err_hlsl_unexpected_end_of_params,
+                           /*param of=*/TokenKind::kw_RootSignature))
     return true;
-  }
+
   return false;
 }
 
@@ -74,12 +73,10 @@ bool RootSignatureParser::parseDescriptorTable() {
       break;
   }
 
-  if (!tryConsumeExpectedToken(TokenKind::pu_r_paren)) {
-    getDiags().Report(CurToken.TokLoc, diag::err_hlsl_unexpected_end_of_params)
-        << /*expected=*/TokenKind::pu_r_paren
-        << /*param of=*/TokenKind::kw_DescriptorTable;
+  if (consumeExpectedToken(TokenKind::pu_r_paren,
+                           diag::err_hlsl_unexpected_end_of_params,
+                           /*param of=*/TokenKind::kw_DescriptorTable))
     return true;
-  }
 
   Elements.push_back(Table);
   return false;
@@ -132,7 +129,7 @@ bool RootSignatureParser::parseDescriptorTableClause() {
     return true;
 
   if (consumeExpectedToken(TokenKind::pu_r_paren, diag::err_hlsl_unexpected_end_of_params,
-                           ParamKind))
+                           /*param of=*/ParamKind))
     return true;
 
   Elements.push_back(Clause);
@@ -279,6 +276,7 @@ bool RootSignatureParser::consumeExpectedToken(TokenKind Expected,
   case diag::err_expected:
     DB << Expected;
     break;
+  case diag::err_hlsl_unexpected_end_of_params:
   case diag::err_expected_either:
   case diag::err_expected_after:
     DB << Expected << Context;

>From cd5c14228eea6eed0a7e551a846d29d4779484e9 Mon Sep 17 00:00:00 2001
From: Finn Plummer <finnplummer at microsoft.com>
Date: Wed, 2 Apr 2025 17:51:15 +0000
Subject: [PATCH 04/11] self-review: fix-up comments

---
 clang/include/clang/Parse/ParseHLSLRootSignature.h | 6 +++---
 clang/lib/Parse/ParseHLSLRootSignature.cpp         | 4 ++--
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/clang/include/clang/Parse/ParseHLSLRootSignature.h b/clang/include/clang/Parse/ParseHLSLRootSignature.h
index f015abe3549aa..4c45af20464e5 100644
--- a/clang/include/clang/Parse/ParseHLSLRootSignature.h
+++ b/clang/include/clang/Parse/ParseHLSLRootSignature.h
@@ -69,8 +69,8 @@ class RootSignatureParser {
   bool parseDescriptorTable();
   bool parseDescriptorTableClause();
 
-  /// Each unique ParamType will have a custom parse method defined that we can
-  /// use to invoke the parameters.
+  /// Each unique ParamType will have a custom parse method defined that can be
+  /// invoked to set a value to the referenced paramtype.
   ///
   /// This function will switch on the ParamType using std::visit and dispatch
   /// onto the corresponding parse method
@@ -86,7 +86,7 @@ class RootSignatureParser {
   ///    TokenKind::kw_space, &Clause.Space
   ///  };
   ///  SmallDenseSet<TokenKind> Mandatory = {
-  ///    TokenKind::kw_numDescriptors
+  ///    TokenKind::bReg
   ///  };
   ///
   /// We can read it is as:
diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp
index 65ff90092a4d7..4b392a9924573 100644
--- a/clang/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -210,6 +210,8 @@ bool RootSignatureParser::parseRegister(Register *Register) {
       "Expects to only be invoked starting at given keyword");
 
   switch (CurToken.Kind) {
+  default:
+    llvm_unreachable("Switch for consumed token was not provided");
   case TokenKind::bReg:
     Register->ViewType = RegisterType::BReg;
     break;
@@ -222,8 +224,6 @@ bool RootSignatureParser::parseRegister(Register *Register) {
   case TokenKind::sReg:
     Register->ViewType = RegisterType::SReg;
     break;
-  default:
-    break; // Unreachable given Try + assert pattern
   }
 
   if (handleUIntLiteral(&Register->Number))

>From 490e6b99b4c99db5bf084c9151cacd91d607bbbc Mon Sep 17 00:00:00 2001
From: Finn Plummer <finnplummer at microsoft.com>
Date: Wed, 2 Apr 2025 18:01:05 +0000
Subject: [PATCH 05/11] rebase changes: update to build issue fixes

---
 .../clang/Parse/ParseHLSLRootSignature.h      | 14 ++++++-------
 clang/lib/Parse/ParseHLSLRootSignature.cpp    | 20 +++++++++----------
 2 files changed, 17 insertions(+), 17 deletions(-)

diff --git a/clang/include/clang/Parse/ParseHLSLRootSignature.h b/clang/include/clang/Parse/ParseHLSLRootSignature.h
index 4c45af20464e5..ffa58e053b68b 100644
--- a/clang/include/clang/Parse/ParseHLSLRootSignature.h
+++ b/clang/include/clang/Parse/ParseHLSLRootSignature.h
@@ -81,12 +81,12 @@ class RootSignatureParser {
   /// as the infastructure to do so in a declarative way.
   ///
   /// For the example:
-  ///  SmallDenseMap<TokenKind, ParamType> Params = {
-  ///    TokenKind::bReg, &Clause.Register,
-  ///    TokenKind::kw_space, &Clause.Space
+  ///  SmallDenseMap<RootSignatureToken::Kind, ParamType> Params = {
+  ///    RootSignatureToken::Kind::bReg, &Clause.Register,
+  ///    RootSignatureToken::Kind::kw_space, &Clause.Space
   ///  };
-  ///  SmallDenseSet<TokenKind> Mandatory = {
-  ///    TokenKind::bReg
+  ///  SmallDenseSet<RootSignatureToken::Kind> Mandatory = {
+  ///    RootSignatureToken::Kind::bReg
   ///  };
   ///
   /// We can read it is as:
@@ -98,8 +98,8 @@ class RootSignatureParser {
   ///
   /// and 'bReg' must be specified
   bool parseParams(
-      llvm::SmallDenseMap<TokenKind, llvm::hlsl::rootsig::ParamType> &Params,
-      llvm::SmallDenseSet<TokenKind> &Mandatory);
+      llvm::SmallDenseMap<RootSignatureToken::Kind, llvm::hlsl::rootsig::ParamType> &Params,
+      llvm::SmallDenseSet<RootSignatureToken::Kind> &Mandatory);
 
   /// Parameter parse methods corresponding to a ParamType
   bool parseUIntParam(uint32_t *X);
diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp
index 4b392a9924573..4fcc4aca3788b 100644
--- a/clang/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -147,7 +147,7 @@ bool RootSignatureParser::parseParam(ParamType Ref) {
           [this](uint32_t *X) -> bool {
             return consumeExpectedToken(TokenKind::pu_equal,
                                         diag::err_expected_after,
-                                        CurToken.Kind) ||
+                                        CurToken.TokKind) ||
                    parseUIntParam(X);
           },
       },
@@ -167,14 +167,14 @@ bool RootSignatureParser::parseParams(
   llvm::SmallDenseSet<TokenKind> Seen;
 
   while (tryConsumeExpectedToken(Keywords)) {
-    if (Seen.contains(CurToken.Kind)) {
+    if (Seen.contains(CurToken.TokKind)) {
       getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_repeat_param)
-          << CurToken.Kind;
+          << CurToken.TokKind;
       return true;
     }
-    Seen.insert(CurToken.Kind);
+    Seen.insert(CurToken.TokKind);
 
-    if (parseParam(Params[CurToken.Kind]))
+    if (parseParam(Params[CurToken.TokKind]))
       return true;
 
     if (!tryConsumeExpectedToken(TokenKind::pu_comma))
@@ -195,21 +195,21 @@ bool RootSignatureParser::parseParams(
 }
 
 bool RootSignatureParser::parseUIntParam(uint32_t *X) {
-  assert(CurToken.Kind == TokenKind::pu_equal &&
+  assert(CurToken.TokKind == TokenKind::pu_equal &&
          "Expects to only be invoked starting at given keyword");
   tryConsumeExpectedToken(TokenKind::pu_plus);
   return consumeExpectedToken(TokenKind::int_literal, diag::err_expected_after,
-                              CurToken.Kind) ||
+                              CurToken.TokKind) ||
          handleUIntLiteral(X);
 }
 
 bool RootSignatureParser::parseRegister(Register *Register) {
   assert(
-      (CurToken.Kind == TokenKind::bReg || CurToken.Kind == TokenKind::tReg ||
-       CurToken.Kind == TokenKind::uReg || CurToken.Kind == TokenKind::sReg) &&
+      (CurToken.TokKind == TokenKind::bReg || CurToken.TokKind == TokenKind::tReg ||
+       CurToken.TokKind == TokenKind::uReg || CurToken.TokKind == TokenKind::sReg) &&
       "Expects to only be invoked starting at given keyword");
 
-  switch (CurToken.Kind) {
+  switch (CurToken.TokKind) {
   default:
     llvm_unreachable("Switch for consumed token was not provided");
   case TokenKind::bReg:

>From 8155c606700c351d047ee7fe72e1c5012190d8e3 Mon Sep 17 00:00:00 2001
From: Finn Plummer <finnplummer at microsoft.com>
Date: Wed, 2 Apr 2025 18:03:38 +0000
Subject: [PATCH 06/11] clang-formatting

---
 clang/include/clang/Parse/ParseHLSLRootSignature.h |  6 +++---
 clang/lib/Parse/ParseHLSLRootSignature.cpp         | 12 +++++++-----
 2 files changed, 10 insertions(+), 8 deletions(-)

diff --git a/clang/include/clang/Parse/ParseHLSLRootSignature.h b/clang/include/clang/Parse/ParseHLSLRootSignature.h
index ffa58e053b68b..afd9f4d589c3a 100644
--- a/clang/include/clang/Parse/ParseHLSLRootSignature.h
+++ b/clang/include/clang/Parse/ParseHLSLRootSignature.h
@@ -97,9 +97,9 @@ class RootSignatureParser {
   ///   of &Clause.Space (uint32_t *) and update the parameter
   ///
   /// and 'bReg' must be specified
-  bool parseParams(
-      llvm::SmallDenseMap<RootSignatureToken::Kind, llvm::hlsl::rootsig::ParamType> &Params,
-      llvm::SmallDenseSet<RootSignatureToken::Kind> &Mandatory);
+  bool parseParams(llvm::SmallDenseMap<RootSignatureToken::Kind,
+                                       llvm::hlsl::rootsig::ParamType> &Params,
+                   llvm::SmallDenseSet<RootSignatureToken::Kind> &Mandatory);
 
   /// Parameter parse methods corresponding to a ParamType
   bool parseUIntParam(uint32_t *X);
diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp
index 4fcc4aca3788b..eb8993fda9e7b 100644
--- a/clang/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -128,7 +128,8 @@ bool RootSignatureParser::parseDescriptorTableClause() {
   if (parseParams(Params, Mandatory))
     return true;
 
-  if (consumeExpectedToken(TokenKind::pu_r_paren, diag::err_hlsl_unexpected_end_of_params,
+  if (consumeExpectedToken(TokenKind::pu_r_paren,
+                           diag::err_hlsl_unexpected_end_of_params,
                            /*param of=*/ParamKind))
     return true;
 
@@ -204,10 +205,11 @@ bool RootSignatureParser::parseUIntParam(uint32_t *X) {
 }
 
 bool RootSignatureParser::parseRegister(Register *Register) {
-  assert(
-      (CurToken.TokKind == TokenKind::bReg || CurToken.TokKind == TokenKind::tReg ||
-       CurToken.TokKind == TokenKind::uReg || CurToken.TokKind == TokenKind::sReg) &&
-      "Expects to only be invoked starting at given keyword");
+  assert((CurToken.TokKind == TokenKind::bReg ||
+          CurToken.TokKind == TokenKind::tReg ||
+          CurToken.TokKind == TokenKind::uReg ||
+          CurToken.TokKind == TokenKind::sReg) &&
+         "Expects to only be invoked starting at given keyword");
 
   switch (CurToken.TokKind) {
   default:

>From 3a598fdbba7bb81867a7aa623cb2f95c7ece881f Mon Sep 17 00:00:00 2001
From: Finn Plummer <finnplummer at microsoft.com>
Date: Wed, 2 Apr 2025 17:46:59 +0000
Subject: [PATCH 07/11] self-review: clang-format fix up

---
 clang/lib/Parse/ParseHLSLRootSignature.cpp | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp
index eb8993fda9e7b..f58566cf6de3c 100644
--- a/clang/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -138,12 +138,15 @@ bool RootSignatureParser::parseDescriptorTableClause() {
 }
 
 // Helper struct defined to use the overloaded notation of std::visit.
-template <class... Ts> struct ParseMethods : Ts... { using Ts::operator()...; };
-template <class... Ts> ParseMethods(Ts...) -> ParseMethods<Ts...>;
+template <class... Ts> struct ParseParamTypeMethods : Ts... {
+  using Ts::operator()...;
+};
+template <class... Ts>
+ParseParamTypeMethods(Ts...) -> ParseParamTypeMethods<Ts...>;
 
 bool RootSignatureParser::parseParam(ParamType Ref) {
   return std::visit(
-      ParseMethods{
+      ParseParamTypeMethods{
           [this](Register *X) -> bool { return parseRegister(X); },
           [this](uint32_t *X) -> bool {
             return consumeExpectedToken(TokenKind::pu_equal,

>From 1d2bd91a01f31e544a6be16bd92a56bf8c9e8998 Mon Sep 17 00:00:00 2001
From: Finn Plummer <finnplummer at microsoft.com>
Date: Fri, 11 Apr 2025 22:15:17 +0000
Subject: [PATCH 08/11] review: prototype using a stateful struct for parsed
 params

pros:
- more explicit mapping to what parse method should be called based on
the current keyword/root element
- removes complexity of using `std::visit` on the parameter types
- allows for validations before the in-memory root element is
constructed

cons:
- we will need to bind the parsed values to the in-memory
representations for each element as opposed to just setting them
directly, but it makes the validations performed more explicit

notes:
- this does not enforce that we should do all validations after parsing
out; we still allow validations on the current token types (for
instance, with register types)
---
 .../clang/Parse/ParseHLSLRootSignature.h      |  86 +++++++-----
 clang/lib/Parse/ParseHLSLRootSignature.cpp    | 129 ++++++++++--------
 2 files changed, 125 insertions(+), 90 deletions(-)

diff --git a/clang/include/clang/Parse/ParseHLSLRootSignature.h b/clang/include/clang/Parse/ParseHLSLRootSignature.h
index afd9f4d589c3a..899deb3ad2bd4 100644
--- a/clang/include/clang/Parse/ParseHLSLRootSignature.h
+++ b/clang/include/clang/Parse/ParseHLSLRootSignature.h
@@ -69,45 +69,67 @@ class RootSignatureParser {
   bool parseDescriptorTable();
   bool parseDescriptorTableClause();
 
-  /// Each unique ParamType will have a custom parse method defined that can be
-  /// invoked to set a value to the referenced paramtype.
-  ///
-  /// This function will switch on the ParamType using std::visit and dispatch
-  /// onto the corresponding parse method
-  bool parseParam(llvm::hlsl::rootsig::ParamType Ref);
-
   /// Parameter arguments (eg. `bReg`, `space`, ...) can be specified in any
-  /// order, exactly once, and only a subset are mandatory. This function acts
-  /// as the infastructure to do so in a declarative way.
-  ///
-  /// For the example:
-  ///  SmallDenseMap<RootSignatureToken::Kind, ParamType> Params = {
-  ///    RootSignatureToken::Kind::bReg, &Clause.Register,
-  ///    RootSignatureToken::Kind::kw_space, &Clause.Space
-  ///  };
-  ///  SmallDenseSet<RootSignatureToken::Kind> Mandatory = {
-  ///    RootSignatureToken::Kind::bReg
-  ///  };
+  /// order and only exactly once. `ParsedParamState` provides a common
+  /// stateful structure to facilitate a common abstraction over collecting
+  /// parameters. By having a common return type we can follow clang's pattern
+  /// of first parsing out the values and then validating when we construct
+  /// the corresponding in-memory RootElements.
+  struct ParsedParamState {
+    // Parameter state to hold the parsed values. The value is only guarentted
+    // to be correct if one of its keyword bits is set in `Seen`.
+    // Eg) if any of the seen bits for `bReg, tReg, uReg, sReg` are set, then
+    // Register will have an initialized value
+    llvm::hlsl::rootsig::Register Register;
+    uint32_t Space;
+
+    // Seen retains whether or not the corresponding keyword has been
+    // seen
+    uint32_t Seen = 0u;
+
+    // Valid keywords for the different parameter types and its corresponding
+    // parsed value
+    ArrayRef<RootSignatureToken::Kind> Keywords;
+
+    // Context of which RootElement is retained to dispatch on the correct
+    // parse method.
+    // Eg) If we encounter kw_flags, it will depend on Context
+    // we are parsing to determine which `parse*Flags` method to call
+    RootSignatureToken::Kind Context;
+
+    // Retain start of params for reporting diagnostics
+    SourceLocation StartLoc;
+
+    // Must provide the starting context of the parameters
+    ParsedParamState(ArrayRef<RootSignatureToken::Kind> Keywords,
+                     RootSignatureToken::Kind Context, SourceLocation StartLoc)
+        : Keywords(Keywords), Context(Context), StartLoc(StartLoc) {}
+
+    // Helper functions to interact with Seen
+    size_t getKeywordIdx(RootSignatureToken::Kind Keyword);
+    bool checkAndSetSeen(RootSignatureToken::Kind Keyword);
+    bool checkAndClearSeen(RootSignatureToken::Kind Keyword);
+  };
+
+  /// Root signature parameters follow the form of `keyword` `=` `value`, or,
+  /// are a register. Given a keyword and the context of which `RootElement`
+  /// type we are parsing, we can dispatch onto the correct parseMethod to
+  /// parse a value into the `ParsedParamState`.
   ///
-  /// We can read it is as:
-  ///
-  /// when 'b0' is encountered, invoke the parse method for the type
-  ///   of &Clause.Register (Register *) and update the parameter
-  /// when 'space' is encountered, invoke a parse method for the type
-  ///   of &Clause.Space (uint32_t *) and update the parameter
-  ///
-  /// and 'bReg' must be specified
-  bool parseParams(llvm::SmallDenseMap<RootSignatureToken::Kind,
-                                       llvm::hlsl::rootsig::ParamType> &Params,
-                   llvm::SmallDenseSet<RootSignatureToken::Kind> &Mandatory);
+  /// This function implements the dispatch onto the correct parse method.
+  bool parseParam(ParsedParamState &Params);
+
+  /// Parses out a `ParsedParamState` for the caller to use for construction
+  /// of the in-memory representation of a Root Element.
+  bool parseParams(ParsedParamState &Params);
 
   /// Parameter parse methods corresponding to a ParamType
-  bool parseUIntParam(uint32_t *X);
-  bool parseRegister(llvm::hlsl::rootsig::Register *Reg);
+  bool parseUIntParam(uint32_t &X);
+  bool parseRegister(llvm::hlsl::rootsig::Register &Reg);
 
   /// Use NumericLiteralParser to convert CurToken.NumSpelling into a unsigned
   /// 32-bit integer
-  bool handleUIntLiteral(uint32_t *X);
+  bool handleUIntLiteral(uint32_t &X);
 
   /// Invoke the Lexer to consume a token and update CurToken with the result
   void consumeNextToken() { CurToken = Lexer.ConsumeToken(); }
diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp
index f58566cf6de3c..34989a7af3e20 100644
--- a/clang/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -117,16 +117,26 @@ bool RootSignatureParser::parseDescriptorTableClause() {
                            ParamKind))
     return true;
 
-  llvm::SmallDenseMap<TokenKind, ParamType> Params = {
-      {ExpectedRegister, &Clause.Register},
-      {TokenKind::kw_space, &Clause.Space},
-  };
-  llvm::SmallDenseSet<TokenKind> Mandatory = {
+  TokenKind Keywords[2] = {
       ExpectedRegister,
+      TokenKind::kw_space,
   };
+  ParsedParamState Params(Keywords, ParamKind, CurToken.TokLoc);
+  if (parseParams(Params))
+    return true;
 
-  if (parseParams(Params, Mandatory))
+  // Mandatory parameters:
+  if (!Params.checkAndClearSeen(ExpectedRegister)) {
+    getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_missing_param)
+        << ExpectedRegister;
     return true;
+  }
+
+  Clause.Register = Params.Register;
+
+  // Optional parameters:
+  if (Params.checkAndClearSeen(TokenKind::kw_space))
+    Clause.Space = Params.Space;
 
   if (consumeExpectedToken(TokenKind::pu_r_paren,
                            diag::err_hlsl_unexpected_end_of_params,
@@ -137,68 +147,71 @@ bool RootSignatureParser::parseDescriptorTableClause() {
   return false;
 }
 
-// Helper struct defined to use the overloaded notation of std::visit.
-template <class... Ts> struct ParseParamTypeMethods : Ts... {
-  using Ts::operator()...;
-};
-template <class... Ts>
-ParseParamTypeMethods(Ts...) -> ParseParamTypeMethods<Ts...>;
-
-bool RootSignatureParser::parseParam(ParamType Ref) {
-  return std::visit(
-      ParseParamTypeMethods{
-          [this](Register *X) -> bool { return parseRegister(X); },
-          [this](uint32_t *X) -> bool {
-            return consumeExpectedToken(TokenKind::pu_equal,
-                                        diag::err_expected_after,
-                                        CurToken.TokKind) ||
-                   parseUIntParam(X);
-          },
-      },
-      Ref);
+size_t RootSignatureParser::ParsedParamState::getKeywordIdx(
+    RootSignatureToken::Kind Keyword) {
+  ArrayRef KeywordRef = Keywords;
+  auto It = llvm::find(KeywordRef, Keyword);
+  assert(It != KeywordRef.end() && "Did not provide a valid param keyword");
+  return std::distance(KeywordRef.begin(), It);
 }
 
-bool RootSignatureParser::parseParams(
-    llvm::SmallDenseMap<TokenKind, ParamType> &Params,
-    llvm::SmallDenseSet<TokenKind> &Mandatory) {
+bool RootSignatureParser::ParsedParamState::checkAndSetSeen(
+    RootSignatureToken::Kind Keyword) {
+  size_t Idx = getKeywordIdx(Keyword);
+  bool WasSeen = Seen & (1 << Idx);
+  Seen |= 1u << Idx;
+  return WasSeen;
+}
 
-  // Initialize a vector of possible keywords
-  SmallVector<TokenKind> Keywords;
-  for (auto Pair : Params)
-    Keywords.push_back(Pair.first);
+bool RootSignatureParser::ParsedParamState::checkAndClearSeen(
+    RootSignatureToken::Kind Keyword) {
+  size_t Idx = getKeywordIdx(Keyword);
+  bool WasSeen = Seen & (1 << Idx);
+  Seen &= ~(1u << Idx);
+  return WasSeen;
+}
 
-  // Keep track of which keywords have been seen to report duplicates
-  llvm::SmallDenseSet<TokenKind> Seen;
+bool RootSignatureParser::parseParam(ParsedParamState &Params) {
+  TokenKind Keyword = CurToken.TokKind;
+  if (Keyword == TokenKind::bReg || Keyword == TokenKind::tReg ||
+      Keyword == TokenKind::uReg || Keyword == TokenKind::sReg) {
+    return parseRegister(Params.Register);
+  }
+
+  if (consumeExpectedToken(TokenKind::pu_equal, diag::err_expected_after,
+                           Keyword))
+    return true;
 
-  while (tryConsumeExpectedToken(Keywords)) {
-    if (Seen.contains(CurToken.TokKind)) {
+  switch (Keyword) {
+  case RootSignatureToken::Kind::kw_space:
+    return parseUIntParam(Params.Space);
+  default:
+    llvm_unreachable("Switch for consumed keyword was not provided");
+  }
+}
+
+bool RootSignatureParser::parseParams(ParsedParamState &Params) {
+  assert(CurToken.TokKind == TokenKind::pu_l_paren &&
+         "Expects to only be invoked starting at given token");
+
+  while (tryConsumeExpectedToken(Params.Keywords)) {
+    if (Params.checkAndSetSeen(CurToken.TokKind)) {
       getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_repeat_param)
           << CurToken.TokKind;
       return true;
     }
-    Seen.insert(CurToken.TokKind);
 
-    if (parseParam(Params[CurToken.TokKind]))
+    if (parseParam(Params))
       return true;
 
     if (!tryConsumeExpectedToken(TokenKind::pu_comma))
       break;
   }
 
-  bool AllMandatoryDefined = true;
-  for (auto Kind : Mandatory) {
-    bool SeenParam = Seen.contains(Kind);
-    if (!SeenParam) {
-      getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_missing_param)
-          << Kind;
-    }
-    AllMandatoryDefined &= SeenParam;
-  }
-
-  return !AllMandatoryDefined;
+  return false;
 }
 
-bool RootSignatureParser::parseUIntParam(uint32_t *X) {
+bool RootSignatureParser::parseUIntParam(uint32_t &X) {
   assert(CurToken.TokKind == TokenKind::pu_equal &&
          "Expects to only be invoked starting at given keyword");
   tryConsumeExpectedToken(TokenKind::pu_plus);
@@ -207,7 +220,7 @@ bool RootSignatureParser::parseUIntParam(uint32_t *X) {
          handleUIntLiteral(X);
 }
 
-bool RootSignatureParser::parseRegister(Register *Register) {
+bool RootSignatureParser::parseRegister(Register &Register) {
   assert((CurToken.TokKind == TokenKind::bReg ||
           CurToken.TokKind == TokenKind::tReg ||
           CurToken.TokKind == TokenKind::uReg ||
@@ -218,26 +231,26 @@ bool RootSignatureParser::parseRegister(Register *Register) {
   default:
     llvm_unreachable("Switch for consumed token was not provided");
   case TokenKind::bReg:
-    Register->ViewType = RegisterType::BReg;
+    Register.ViewType = RegisterType::BReg;
     break;
   case TokenKind::tReg:
-    Register->ViewType = RegisterType::TReg;
+    Register.ViewType = RegisterType::TReg;
     break;
   case TokenKind::uReg:
-    Register->ViewType = RegisterType::UReg;
+    Register.ViewType = RegisterType::UReg;
     break;
   case TokenKind::sReg:
-    Register->ViewType = RegisterType::SReg;
+    Register.ViewType = RegisterType::SReg;
     break;
   }
 
-  if (handleUIntLiteral(&Register->Number))
+  if (handleUIntLiteral(Register.Number))
     return true; // propogate NumericLiteralParser error
 
   return false;
 }
 
-bool RootSignatureParser::handleUIntLiteral(uint32_t *X) {
+bool RootSignatureParser::handleUIntLiteral(uint32_t &X) {
   // Parse the numeric value and do semantic checks on its specification
   clang::NumericLiteralParser Literal(CurToken.NumSpelling, CurToken.TokLoc,
                                       PP.getSourceManager(), PP.getLangOpts(),
@@ -256,7 +269,7 @@ bool RootSignatureParser::handleUIntLiteral(uint32_t *X) {
     return true;
   }
 
-  *X = Val.getExtValue();
+  X = Val.getExtValue();
   return false;
 }
 

>From 5d1ec2f8a1239d68111f930d5c329759db83ebc0 Mon Sep 17 00:00:00 2001
From: Finn Plummer <finnplummer at microsoft.com>
Date: Thu, 17 Apr 2025 18:11:06 +0000
Subject: [PATCH 09/11] review: remove abstraction of parsing arbitrary params

---
 .../clang/Parse/ParseHLSLRootSignature.h      | 56 +----------
 clang/lib/Parse/ParseHLSLRootSignature.cpp    | 99 ++++++-------------
 .../llvm/Frontend/HLSL/HLSLRootSignature.h    |  6 --
 3 files changed, 37 insertions(+), 124 deletions(-)

diff --git a/clang/include/clang/Parse/ParseHLSLRootSignature.h b/clang/include/clang/Parse/ParseHLSLRootSignature.h
index 899deb3ad2bd4..baf516de3e570 100644
--- a/clang/include/clang/Parse/ParseHLSLRootSignature.h
+++ b/clang/include/clang/Parse/ParseHLSLRootSignature.h
@@ -69,59 +69,13 @@ class RootSignatureParser {
   bool parseDescriptorTable();
   bool parseDescriptorTableClause();
 
-  /// Parameter arguments (eg. `bReg`, `space`, ...) can be specified in any
-  /// order and only exactly once. `ParsedParamState` provides a common
-  /// stateful structure to facilitate a common abstraction over collecting
-  /// parameters. By having a common return type we can follow clang's pattern
-  /// of first parsing out the values and then validating when we construct
-  /// the corresponding in-memory RootElements.
-  struct ParsedParamState {
-    // Parameter state to hold the parsed values. The value is only guarentted
-    // to be correct if one of its keyword bits is set in `Seen`.
-    // Eg) if any of the seen bits for `bReg, tReg, uReg, sReg` are set, then
-    // Register will have an initialized value
-    llvm::hlsl::rootsig::Register Register;
-    uint32_t Space;
-
-    // Seen retains whether or not the corresponding keyword has been
-    // seen
-    uint32_t Seen = 0u;
-
-    // Valid keywords for the different parameter types and its corresponding
-    // parsed value
-    ArrayRef<RootSignatureToken::Kind> Keywords;
-
-    // Context of which RootElement is retained to dispatch on the correct
-    // parse method.
-    // Eg) If we encounter kw_flags, it will depend on Context
-    // we are parsing to determine which `parse*Flags` method to call
-    RootSignatureToken::Kind Context;
-
-    // Retain start of params for reporting diagnostics
-    SourceLocation StartLoc;
-
-    // Must provide the starting context of the parameters
-    ParsedParamState(ArrayRef<RootSignatureToken::Kind> Keywords,
-                     RootSignatureToken::Kind Context, SourceLocation StartLoc)
-        : Keywords(Keywords), Context(Context), StartLoc(StartLoc) {}
-
-    // Helper functions to interact with Seen
-    size_t getKeywordIdx(RootSignatureToken::Kind Keyword);
-    bool checkAndSetSeen(RootSignatureToken::Kind Keyword);
-    bool checkAndClearSeen(RootSignatureToken::Kind Keyword);
+  struct ParsedParams {
+    std::optional<llvm::hlsl::rootsig::Register> Register;
+    std::optional<uint32_t> Space;
   };
-
-  /// Root signature parameters follow the form of `keyword` `=` `value`, or,
-  /// are a register. Given a keyword and the context of which `RootElement`
-  /// type we are parsing, we can dispatch onto the correct parseMethod to
-  /// parse a value into the `ParsedParamState`.
-  ///
-  /// This function implements the dispatch onto the correct parse method.
-  bool parseParam(ParsedParamState &Params);
-
-  /// Parses out a `ParsedParamState` for the caller to use for construction
+  /// Parses out a `ParsedParams` for the caller to use for construction
   /// of the in-memory representation of a Root Element.
-  bool parseParams(ParsedParamState &Params);
+  bool parseDescriptorTableClauseParams(ParsedParams &Params, RootSignatureToken::Kind RegType);
 
   /// Parameter parse methods corresponding to a ParamType
   bool parseUIntParam(uint32_t &X);
diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp
index 34989a7af3e20..e7fd94eb75e75 100644
--- a/clang/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -117,26 +117,21 @@ bool RootSignatureParser::parseDescriptorTableClause() {
                            ParamKind))
     return true;
 
-  TokenKind Keywords[2] = {
-      ExpectedRegister,
-      TokenKind::kw_space,
-  };
-  ParsedParamState Params(Keywords, ParamKind, CurToken.TokLoc);
-  if (parseParams(Params))
+  ParsedParams Result;
+  if (parseDescriptorTableClauseParams(Result, ExpectedRegister))
     return true;
 
-  // Mandatory parameters:
-  if (!Params.checkAndClearSeen(ExpectedRegister)) {
+  // Check mandatory parameters were provided
+  if (!Result.Register.has_value()) {
     getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_missing_param)
         << ExpectedRegister;
     return true;
   }
 
-  Clause.Register = Params.Register;
+  Clause.Register = *Result.Register;
 
-  // Optional parameters:
-  if (Params.checkAndClearSeen(TokenKind::kw_space))
-    Clause.Space = Params.Space;
+  if (Result.Space)
+    Clause.Space = *Result.Space;
 
   if (consumeExpectedToken(TokenKind::pu_r_paren,
                            diag::err_hlsl_unexpected_end_of_params,
@@ -147,66 +142,36 @@ bool RootSignatureParser::parseDescriptorTableClause() {
   return false;
 }
 
-size_t RootSignatureParser::ParsedParamState::getKeywordIdx(
-    RootSignatureToken::Kind Keyword) {
-  ArrayRef KeywordRef = Keywords;
-  auto It = llvm::find(KeywordRef, Keyword);
-  assert(It != KeywordRef.end() && "Did not provide a valid param keyword");
-  return std::distance(KeywordRef.begin(), It);
-}
-
-bool RootSignatureParser::ParsedParamState::checkAndSetSeen(
-    RootSignatureToken::Kind Keyword) {
-  size_t Idx = getKeywordIdx(Keyword);
-  bool WasSeen = Seen & (1 << Idx);
-  Seen |= 1u << Idx;
-  return WasSeen;
-}
-
-bool RootSignatureParser::ParsedParamState::checkAndClearSeen(
-    RootSignatureToken::Kind Keyword) {
-  size_t Idx = getKeywordIdx(Keyword);
-  bool WasSeen = Seen & (1 << Idx);
-  Seen &= ~(1u << Idx);
-  return WasSeen;
-}
-
-bool RootSignatureParser::parseParam(ParsedParamState &Params) {
-  TokenKind Keyword = CurToken.TokKind;
-  if (Keyword == TokenKind::bReg || Keyword == TokenKind::tReg ||
-      Keyword == TokenKind::uReg || Keyword == TokenKind::sReg) {
-    return parseRegister(Params.Register);
-  }
-
-  if (consumeExpectedToken(TokenKind::pu_equal, diag::err_expected_after,
-                           Keyword))
-    return true;
-
-  switch (Keyword) {
-  case RootSignatureToken::Kind::kw_space:
-    return parseUIntParam(Params.Space);
-  default:
-    llvm_unreachable("Switch for consumed keyword was not provided");
-  }
-}
-
-bool RootSignatureParser::parseParams(ParsedParamState &Params) {
+bool RootSignatureParser::parseDescriptorTableClauseParams(ParsedParams &Params, TokenKind RegType) {
   assert(CurToken.TokKind == TokenKind::pu_l_paren &&
          "Expects to only be invoked starting at given token");
 
-  while (tryConsumeExpectedToken(Params.Keywords)) {
-    if (Params.checkAndSetSeen(CurToken.TokKind)) {
-      getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_repeat_param)
+  do {
+    if (tryConsumeExpectedToken(RegType)) {
+      if (Params.Register.has_value()) {
+        getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_repeat_param)
           << CurToken.TokKind;
-      return true;
+        return true;
+      }
+      Register Reg;
+      if (parseRegister(Reg))
+        return true;
+      Params.Register = Reg;
     }
-
-    if (parseParam(Params))
-      return true;
-
-    if (!tryConsumeExpectedToken(TokenKind::pu_comma))
-      break;
-  }
+    if (tryConsumeExpectedToken(TokenKind::kw_space)) {
+      if (Params.Space.has_value()) {
+        getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_repeat_param)
+          << CurToken.TokKind;
+        return true;
+      }
+      if (consumeExpectedToken(TokenKind::pu_equal))
+        return true;
+      uint32_t Space;
+      if (parseUIntParam(Space))
+        return true;
+      Params.Space = Space;
+    }
+  } while (tryConsumeExpectedToken(TokenKind::pu_comma));
 
   return false;
 }
diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
index 0ae8879b6c7d5..6a35c114100b9 100644
--- a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
+++ b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
@@ -46,12 +46,6 @@ struct DescriptorTableClause {
 // Models RootElement : DescriptorTable | DescriptorTableClause
 using RootElement = std::variant<DescriptorTable, DescriptorTableClause>;
 
-// ParamType is used as an 'any' type that will reference to a parameter in
-// RootElement. Each variant of ParamType is expected to have a Parse method
-// defined that will be dispatched on when we are attempting to parse a
-// parameter
-using ParamType = std::variant<uint32_t *, Register *>;
-
 } // namespace rootsig
 } // namespace hlsl
 } // namespace llvm

>From 47139b640854cbaf685f2baa36a217d8c60d0abd Mon Sep 17 00:00:00 2001
From: Finn Plummer <finnplummer at microsoft.com>
Date: Thu, 17 Apr 2025 18:45:05 +0000
Subject: [PATCH 10/11] NFC review: switch calling convention to use
 std::optional instead of bool

---
 .../clang/Parse/ParseHLSLRootSignature.h      |  14 +-
 clang/lib/Parse/ParseHLSLRootSignature.cpp    | 157 +++++++++---------
 2 files changed, 86 insertions(+), 85 deletions(-)

diff --git a/clang/include/clang/Parse/ParseHLSLRootSignature.h b/clang/include/clang/Parse/ParseHLSLRootSignature.h
index baf516de3e570..d5ab166e2d35d 100644
--- a/clang/include/clang/Parse/ParseHLSLRootSignature.h
+++ b/clang/include/clang/Parse/ParseHLSLRootSignature.h
@@ -66,8 +66,9 @@ class RootSignatureParser {
   // expected, or, there is a lexing error
 
   /// Root Element parse methods:
-  bool parseDescriptorTable();
-  bool parseDescriptorTableClause();
+  std::optional<llvm::hlsl::rootsig::DescriptorTable> parseDescriptorTable();
+  std::optional<llvm::hlsl::rootsig::DescriptorTableClause>
+  parseDescriptorTableClause();
 
   struct ParsedParams {
     std::optional<llvm::hlsl::rootsig::Register> Register;
@@ -75,15 +76,16 @@ class RootSignatureParser {
   };
   /// Parses out a `ParsedParams` for the caller to use for construction
   /// of the in-memory representation of a Root Element.
-  bool parseDescriptorTableClauseParams(ParsedParams &Params, RootSignatureToken::Kind RegType);
+  std::optional<ParsedParams>
+  parseDescriptorTableClauseParams(RootSignatureToken::Kind RegType);
 
   /// Parameter parse methods corresponding to a ParamType
-  bool parseUIntParam(uint32_t &X);
-  bool parseRegister(llvm::hlsl::rootsig::Register &Reg);
+  std::optional<uint32_t> parseUIntParam();
+  std::optional<llvm::hlsl::rootsig::Register> parseRegister();
 
   /// Use NumericLiteralParser to convert CurToken.NumSpelling into a unsigned
   /// 32-bit integer
-  bool handleUIntLiteral(uint32_t &X);
+  std::optional<uint32_t> handleUIntLiteral();
 
   /// Invoke the Lexer to consume a token and update CurToken with the result
   void consumeNextToken() { CurToken = Lexer.ConsumeToken(); }
diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp
index e7fd94eb75e75..ae84349fa42ee 100644
--- a/clang/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -26,22 +26,14 @@ RootSignatureParser::RootSignatureParser(SmallVector<RootElement> &Elements,
 
 bool RootSignatureParser::parse() {
   // Iterate as many RootElements as possible
-  while (tryConsumeExpectedToken(TokenKind::kw_DescriptorTable)) {
-    // Dispatch onto parser method.
-    // We guard against the unreachable here as we just ensured that CurToken
-    // will be one of the kinds in the while condition
-    switch (CurToken.TokKind) {
-    case TokenKind::kw_DescriptorTable:
-      if (parseDescriptorTable())
+  do {
+    if (tryConsumeExpectedToken(TokenKind::kw_DescriptorTable)) {
+      auto Table = parseDescriptorTable();
+      if (!Table.has_value())
         return true;
-      break;
-    default:
-      llvm_unreachable("Switch for consumed token was not provided");
+      Elements.push_back(*Table);
     }
-
-    if (!tryConsumeExpectedToken(TokenKind::pu_comma))
-      break;
-  }
+  } while (tryConsumeExpectedToken(TokenKind::pu_comma));
 
   if (consumeExpectedToken(TokenKind::end_of_stream,
                            diag::err_hlsl_unexpected_end_of_params,
@@ -51,147 +43,153 @@ bool RootSignatureParser::parse() {
   return false;
 }
 
-bool RootSignatureParser::parseDescriptorTable() {
+std::optional<DescriptorTable> RootSignatureParser::parseDescriptorTable() {
   assert(CurToken.TokKind == TokenKind::kw_DescriptorTable &&
          "Expects to only be invoked starting at given keyword");
 
-  DescriptorTable Table;
-
   if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
                            CurToken.TokKind))
-    return true;
+    return std::nullopt;
 
-  // Iterate as many Clauses as possible
-  while (tryConsumeExpectedToken({TokenKind::kw_CBV, TokenKind::kw_SRV,
-                                  TokenKind::kw_UAV, TokenKind::kw_Sampler})) {
-    if (parseDescriptorTableClause())
-      return true;
-
-    Table.NumClauses++;
+  DescriptorTable Table;
 
-    if (!tryConsumeExpectedToken(TokenKind::pu_comma))
-      break;
-  }
+  // Iterate as many Clauses as possible
+  do {
+    if (tryConsumeExpectedToken({TokenKind::kw_CBV, TokenKind::kw_SRV,
+                                 TokenKind::kw_UAV, TokenKind::kw_Sampler})) {
+      auto Clause = parseDescriptorTableClause();
+      if (!Clause.has_value())
+        return std::nullopt;
+      Elements.push_back(*Clause);
+      Table.NumClauses++;
+    }
+  } while (tryConsumeExpectedToken(TokenKind::pu_comma));
 
   if (consumeExpectedToken(TokenKind::pu_r_paren,
                            diag::err_hlsl_unexpected_end_of_params,
                            /*param of=*/TokenKind::kw_DescriptorTable))
-    return true;
+    return std::nullopt;
 
-  Elements.push_back(Table);
-  return false;
+  return Table;
 }
 
-bool RootSignatureParser::parseDescriptorTableClause() {
+std::optional<DescriptorTableClause>
+RootSignatureParser::parseDescriptorTableClause() {
   assert((CurToken.TokKind == TokenKind::kw_CBV ||
           CurToken.TokKind == TokenKind::kw_SRV ||
           CurToken.TokKind == TokenKind::kw_UAV ||
           CurToken.TokKind == TokenKind::kw_Sampler) &&
          "Expects to only be invoked starting at given keyword");
-  TokenKind ParamKind = CurToken.TokKind; // retain for diagnostics
+
+  TokenKind ParamKind = CurToken.TokKind;
+
+  if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
+                           CurToken.TokKind))
+    return std::nullopt;
 
   DescriptorTableClause Clause;
-  TokenKind ExpectedRegister;
+  TokenKind ExpectedReg;
   switch (ParamKind) {
   default:
     llvm_unreachable("Switch for consumed token was not provided");
   case TokenKind::kw_CBV:
     Clause.Type = ClauseType::CBuffer;
-    ExpectedRegister = TokenKind::bReg;
+    ExpectedReg = TokenKind::bReg;
     break;
   case TokenKind::kw_SRV:
     Clause.Type = ClauseType::SRV;
-    ExpectedRegister = TokenKind::tReg;
+    ExpectedReg = TokenKind::tReg;
     break;
   case TokenKind::kw_UAV:
     Clause.Type = ClauseType::UAV;
-    ExpectedRegister = TokenKind::uReg;
+    ExpectedReg = TokenKind::uReg;
     break;
   case TokenKind::kw_Sampler:
     Clause.Type = ClauseType::Sampler;
-    ExpectedRegister = TokenKind::sReg;
+    ExpectedReg = TokenKind::sReg;
     break;
   }
 
-  if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
-                           ParamKind))
-    return true;
-
-  ParsedParams Result;
-  if (parseDescriptorTableClauseParams(Result, ExpectedRegister))
-    return true;
+  auto Params = parseDescriptorTableClauseParams(ExpectedReg);
+  if (!Params.has_value())
+    return std::nullopt;
 
   // Check mandatory parameters were provided
-  if (!Result.Register.has_value()) {
+  if (!Params->Register.has_value()) {
     getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_missing_param)
-        << ExpectedRegister;
-    return true;
+        << ExpectedReg;
+    return std::nullopt;
   }
 
-  Clause.Register = *Result.Register;
+  Clause.Register = Params->Register.value();
 
-  if (Result.Space)
-    Clause.Space = *Result.Space;
+  // Fill in optional values
+  if (Params->Space.has_value())
+    Clause.Space = Params->Space.value();
 
   if (consumeExpectedToken(TokenKind::pu_r_paren,
                            diag::err_hlsl_unexpected_end_of_params,
                            /*param of=*/ParamKind))
-    return true;
+    return std::nullopt;
 
-  Elements.push_back(Clause);
-  return false;
+  return Clause;
 }
 
-bool RootSignatureParser::parseDescriptorTableClauseParams(ParsedParams &Params, TokenKind RegType) {
+std::optional<RootSignatureParser::ParsedParams>
+RootSignatureParser::parseDescriptorTableClauseParams(TokenKind RegType) {
   assert(CurToken.TokKind == TokenKind::pu_l_paren &&
          "Expects to only be invoked starting at given token");
 
+  ParsedParams Params;
   do {
     if (tryConsumeExpectedToken(RegType)) {
       if (Params.Register.has_value()) {
         getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_repeat_param)
-          << CurToken.TokKind;
-        return true;
+            << CurToken.TokKind;
+        return std::nullopt;
       }
-      Register Reg;
-      if (parseRegister(Reg))
-        return true;
+      auto Reg = parseRegister();
+      if (!Reg.has_value())
+        return std::nullopt;
       Params.Register = Reg;
     }
+
     if (tryConsumeExpectedToken(TokenKind::kw_space)) {
       if (Params.Space.has_value()) {
         getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_repeat_param)
-          << CurToken.TokKind;
-        return true;
+            << CurToken.TokKind;
+        return std::nullopt;
       }
       if (consumeExpectedToken(TokenKind::pu_equal))
-        return true;
-      uint32_t Space;
-      if (parseUIntParam(Space))
-        return true;
+        return std::nullopt;
+      auto Space = parseUIntParam();
+      if (!Space.has_value())
+        return std::nullopt;
       Params.Space = Space;
     }
   } while (tryConsumeExpectedToken(TokenKind::pu_comma));
 
-  return false;
+  return Params;
 }
 
-bool RootSignatureParser::parseUIntParam(uint32_t &X) {
+std::optional<uint32_t> RootSignatureParser::parseUIntParam() {
   assert(CurToken.TokKind == TokenKind::pu_equal &&
          "Expects to only be invoked starting at given keyword");
   tryConsumeExpectedToken(TokenKind::pu_plus);
-  return consumeExpectedToken(TokenKind::int_literal, diag::err_expected_after,
-                              CurToken.TokKind) ||
-         handleUIntLiteral(X);
+  if (consumeExpectedToken(TokenKind::int_literal, diag::err_expected_after,
+                           CurToken.TokKind))
+    return std::nullopt;
+  return handleUIntLiteral();
 }
 
-bool RootSignatureParser::parseRegister(Register &Register) {
+std::optional<Register> RootSignatureParser::parseRegister() {
   assert((CurToken.TokKind == TokenKind::bReg ||
           CurToken.TokKind == TokenKind::tReg ||
           CurToken.TokKind == TokenKind::uReg ||
           CurToken.TokKind == TokenKind::sReg) &&
          "Expects to only be invoked starting at given keyword");
 
+  Register Register;
   switch (CurToken.TokKind) {
   default:
     llvm_unreachable("Switch for consumed token was not provided");
@@ -209,13 +207,15 @@ bool RootSignatureParser::parseRegister(Register &Register) {
     break;
   }
 
-  if (handleUIntLiteral(Register.Number))
-    return true; // propogate NumericLiteralParser error
+  auto Number = handleUIntLiteral();
+  if (!Number.has_value())
+    return std::nullopt; // propagate NumericLiteralParser error
 
-  return false;
+  Register.Number = *Number;
+  return Register;
 }
 
-bool RootSignatureParser::handleUIntLiteral(uint32_t &X) {
+std::optional<uint32_t> RootSignatureParser::handleUIntLiteral() {
   // Parse the numeric value and do semantic checks on its specification
   clang::NumericLiteralParser Literal(CurToken.NumSpelling, CurToken.TokLoc,
                                       PP.getSourceManager(), PP.getLangOpts(),
@@ -231,11 +231,10 @@ bool RootSignatureParser::handleUIntLiteral(uint32_t &X) {
     PP.getDiagnostics().Report(CurToken.TokLoc,
                                diag::err_hlsl_number_literal_overflow)
         << 0 << CurToken.NumSpelling;
-    return true;
+    return std::nullopt;
   }
 
-  X = Val.getExtValue();
-  return false;
+  return Val.getExtValue();
 }
 
 bool RootSignatureParser::peekExpectedToken(TokenKind Expected) {

>From 2df5d6d23efa7a6cf6914a696915cb9f9b4e3844 Mon Sep 17 00:00:00 2001
From: Finn Plummer <finnplummer at microsoft.com>
Date: Thu, 17 Apr 2025 18:55:07 +0000
Subject: [PATCH 11/11] NFC self-review: update comments

---
 .../clang/Parse/ParseHLSLRootSignature.h      | 36 +++++++++++--------
 clang/lib/Parse/ParseHLSLRootSignature.cpp    | 11 ++++--
 2 files changed, 30 insertions(+), 17 deletions(-)

diff --git a/clang/include/clang/Parse/ParseHLSLRootSignature.h b/clang/include/clang/Parse/ParseHLSLRootSignature.h
index d5ab166e2d35d..e4e63e2e999a9 100644
--- a/clang/include/clang/Parse/ParseHLSLRootSignature.h
+++ b/clang/include/clang/Parse/ParseHLSLRootSignature.h
@@ -40,26 +40,31 @@ class RootSignatureParser {
 private:
   DiagnosticsEngine &getDiags() { return PP.getDiagnostics(); }
 
-  // All private Parse.* methods follow a similar pattern:
+  // All private parse.* methods follow a similar pattern:
   //   - Each method will start with an assert to denote what the CurToken is
   // expected to be and will parse from that token forward
   //
   //   - Therefore, it is the callers responsibility to ensure that you are
   // at the correct CurToken. This should be done with the pattern of:
   //
-  //  if (TryConsumeExpectedToken(RootSignatureToken::Kind))
-  //    if (Parse.*())
-  //      return true;
+  //  if (tryConsumeExpectedToken(RootSignatureToken::Kind)) {
+  //    auto ParsedObject = parse.*();
+  //    if (!ParsedObject.has_value())
+  //      return std::nullopt;
+  //    ...
+  // }
   //
   // or,
   //
-  //  if (ConsumeExpectedToken(RootSignatureToken::Kind, ...))
-  //    return true;
-  //  if (Parse.*())
-  //    return true;
+  //  if (consumeExpectedToken(RootSignatureToken::Kind, ...))
+  //    return std::nullopt;
+  //  auto ParsedObject = parse.*();
+  //  if (!ParsedObject.has_value())
+  //    return std::nullopt;
+  //  ...
   //
-  //   - All methods return true if a parsing error is encountered. It is the
-  // callers responsibility to propogate this error up, or deal with it
+  //   - All methods return std::nullopt if a parsing error is encountered. It
+  // is the caller's responsibility to propagate this error up, or deal with it
   // otherwise
   //
   //   - An error will be raised if the proceeding tokens are not what is
@@ -70,16 +75,17 @@ class RootSignatureParser {
   std::optional<llvm::hlsl::rootsig::DescriptorTableClause>
   parseDescriptorTableClause();
 
-  struct ParsedParams {
+  /// Parameter parse methods
+  /// Parameter arguments (e.g. `bReg`, `space`, ...) can be specified in any
+  /// order, each at most once. `ParsedClauseParams` denotes the current
+  /// state of parsed params
+  struct ParsedClauseParams {
     std::optional<llvm::hlsl::rootsig::Register> Register;
     std::optional<uint32_t> Space;
   };
-  /// Parses out a `ParsedParams` for the caller to use for construction
-  /// of the in-memory representation of a Root Element.
-  std::optional<ParsedParams>
+  std::optional<ParsedClauseParams>
   parseDescriptorTableClauseParams(RootSignatureToken::Kind RegType);
 
-  /// Parameter parse methods corresponding to a ParamType
   std::optional<uint32_t> parseUIntParam();
   std::optional<llvm::hlsl::rootsig::Register> parseRegister();
 
diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp
index ae84349fa42ee..698c6f2b7351d 100644
--- a/clang/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -135,13 +135,17 @@ RootSignatureParser::parseDescriptorTableClause() {
   return Clause;
 }
 
-std::optional<RootSignatureParser::ParsedParams>
+std::optional<RootSignatureParser::ParsedClauseParams>
 RootSignatureParser::parseDescriptorTableClauseParams(TokenKind RegType) {
   assert(CurToken.TokKind == TokenKind::pu_l_paren &&
          "Expects to only be invoked starting at given token");
 
-  ParsedParams Params;
+  // Parameter arguments (e.g. `bReg`, `space`, ...) can be specified in any
+  // order, each at most once. Parse through as many arguments as possible,
+  // reporting an error if a duplicate is seen.
+  ParsedClauseParams Params;
   do {
+    // ( `b` | `t` | `u` | `s`) POS_INT
     if (tryConsumeExpectedToken(RegType)) {
       if (Params.Register.has_value()) {
         getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_repeat_param)
@@ -154,14 +158,17 @@ RootSignatureParser::parseDescriptorTableClauseParams(TokenKind RegType) {
       Params.Register = Reg;
     }
 
+    // `space` `=` POS_INT
     if (tryConsumeExpectedToken(TokenKind::kw_space)) {
       if (Params.Space.has_value()) {
         getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_repeat_param)
             << CurToken.TokKind;
         return std::nullopt;
       }
+
       if (consumeExpectedToken(TokenKind::pu_equal))
         return std::nullopt;
+
       auto Space = parseUIntParam();
       if (!Space.has_value())
         return std::nullopt;



More information about the llvm-commits mailing list