[llvm-branch-commits] [clang] [llvm] [HLSL][RootSignature] Add infrastructure to parse parameters (PR #133800)

Finn Plummer via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Mon Mar 31 14:21:25 PDT 2025


https://github.com/inbelic created https://github.com/llvm/llvm-project/pull/133800

- defines `ParamType` as a way to represent a reference to some
parameter in a root signature

- defines `parseParam` and `parseParams` as the infrastructure that
declares how the parameters of a given struct should be parsed, in any
order

- implements parsing of two param types: `uint32_t` and `Register` to
demonstrate the parsing implementation and allow for unit testing

Part two of implementing: https://github.com/llvm/llvm-project/issues/126569

>From 71ae66142abb2e6e43a60ce521dc638ec702399f Mon Sep 17 00:00:00 2001
From: Finn Plummer <finnplummer at microsoft.com>
Date: Mon, 31 Mar 2025 18:29:26 +0000
Subject: [PATCH 1/3] [HLSL][RootSignature] Add infrastructure to parse
 parameters

- defines `ParamType` as a way to represent a reference to some
parameter in a root signature

- defines `parseParam` and `parseParams` as the infrastructure that
declares how the parameters of a given struct should be parsed, in any
order
---
 .../clang/Basic/DiagnosticParseKinds.td       |  4 +-
 .../clang/Parse/ParseHLSLRootSignature.h      | 32 ++++++++++++
 clang/lib/Parse/ParseHLSLRootSignature.cpp    | 51 +++++++++++++++++++
 .../llvm/Frontend/HLSL/HLSLRootSignature.h    |  6 +++
 4 files changed, 92 insertions(+), 1 deletion(-)

diff --git a/clang/include/clang/Basic/DiagnosticParseKinds.td b/clang/include/clang/Basic/DiagnosticParseKinds.td
index 2582e1e5ef0f6..f25c9c930cc61 100644
--- a/clang/include/clang/Basic/DiagnosticParseKinds.td
+++ b/clang/include/clang/Basic/DiagnosticParseKinds.td
@@ -1830,8 +1830,10 @@ def err_hlsl_virtual_function
 def err_hlsl_virtual_inheritance
     : Error<"virtual inheritance is unsupported in HLSL">;
 
-// HLSL Root Siganture diagnostic messages
+// HLSL Root Signature Parser Diagnostics
 def err_hlsl_unexpected_end_of_params
     : Error<"expected %0 to denote end of parameters, or, another valid parameter of %1">;
+def err_hlsl_rootsig_repeat_param : Error<"specified the same parameter '%0' multiple times">;
+def err_hlsl_rootsig_missing_param : Error<"did not specify mandatory parameter '%0'">;
 
 } // end of Parser diagnostics
diff --git a/clang/include/clang/Parse/ParseHLSLRootSignature.h b/clang/include/clang/Parse/ParseHLSLRootSignature.h
index 43b41315b88b5..b39267dbaaa28 100644
--- a/clang/include/clang/Parse/ParseHLSLRootSignature.h
+++ b/clang/include/clang/Parse/ParseHLSLRootSignature.h
@@ -69,6 +69,38 @@ class RootSignatureParser {
   bool parseDescriptorTable();
   bool parseDescriptorTableClause();
 
+  /// Each unique ParamType will have a custom parse method defined that we can
+  /// use to parse the parameter it references.
+  ///
+  /// This function will switch on the ParamType using std::visit and dispatch
+  /// onto the corresponding parse method. Returns true on error.
+  bool parseParam(llvm::hlsl::rootsig::ParamType Ref);
+
+  /// Parameter arguments (eg. `bReg`, `space`, ...) can be specified in any
+  /// order, at most once, and only a subset are mandatory. This function acts
+  /// as the infrastructure to do so in a declarative way.
+  ///
+  /// For the example:
+  ///  SmallDenseMap<TokenKind, ParamType> Params = {
+  ///    {TokenKind::bReg, &Clause.Register},
+  ///    {TokenKind::kw_space, &Clause.Space},
+  ///  };
+  ///  SmallDenseSet<TokenKind> Mandatory = {
+  ///    TokenKind::bReg,
+  ///  };
+  ///
+  /// We can read it as:
+  ///
+  /// when 'b0' is encountered, invoke the parse method for the type
+  ///   of &Clause.Register (Register *) and update the parameter
+  /// when 'space' is encountered, invoke a parse method for the type
+  ///   of &Clause.Space (uint32_t *) and update the parameter
+  ///
+  /// and 'bReg' must be specified. Returns true on error.
+  bool parseParams(
+      llvm::SmallDenseMap<TokenKind, llvm::hlsl::rootsig::ParamType> &Params,
+      llvm::SmallDenseSet<TokenKind> &Mandatory);
+
   /// Invoke the Lexer to consume a token and update CurToken with the result
   void consumeNextToken() { CurToken = Lexer.ConsumeToken(); }
 
diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp
index 33caca5fa1c82..8bb78def243fe 100644
--- a/clang/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -118,6 +118,57 @@ bool RootSignatureParser::parseDescriptorTableClause() {
   return false;
 }
 
+// Helper struct so that we can use the overloaded notation of std::visit
+template <class... Ts> struct ParseMethods : Ts... { using Ts::operator()...; };
+template <class... Ts> ParseMethods(Ts...) -> ParseMethods<Ts...>;
+
+bool RootSignatureParser::parseParam(ParamType Ref) {
+  bool Error = true;
+  std::visit(ParseMethods{}, Ref);
+
+  return Error;
+}
+
+bool RootSignatureParser::parseParams(
+    llvm::SmallDenseMap<TokenKind, ParamType> &Params,
+    llvm::SmallDenseSet<TokenKind> &Mandatory) {
+
+  // Initialize a vector of possible keywords
+  SmallVector<TokenKind> Keywords;
+  for (auto Pair : Params)
+    Keywords.push_back(Pair.first);
+
+  // Keep track of which keywords have been seen to report duplicates
+  llvm::SmallDenseSet<TokenKind> Seen;
+
+  while (tryConsumeExpectedToken(Keywords)) {
+    if (Seen.contains(CurToken.Kind)) {
+      getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_repeat_param)
+          << CurToken.Kind;
+      return true;
+    }
+    Seen.insert(CurToken.Kind);
+
+    if (parseParam(Params[CurToken.Kind]))
+      return true;
+
+    if (!tryConsumeExpectedToken(TokenKind::pu_comma))
+      break;
+  }
+
+  bool AllMandatoryDefined = true;
+  for (auto Kind : Mandatory) {
+    bool SeenParam = Seen.contains(Kind);
+    if (!SeenParam) {
+      getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_missing_param)
+          << Kind;
+    }
+    AllMandatoryDefined &= SeenParam;
+  }
+
+  return !AllMandatoryDefined;
+}
+
 bool RootSignatureParser::peekExpectedToken(TokenKind Expected) {
   return peekExpectedToken(ArrayRef{Expected});
 }
diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
index c1b67844c747f..825beeea961cd 100644
--- a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
+++ b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
@@ -37,6 +37,12 @@ struct DescriptorTableClause {
 // Models RootElement : DescriptorTable | DescriptorTableClause
 using RootElement = std::variant<DescriptorTable, DescriptorTableClause>;
 
+// ParamType is used as an 'any' type that will reference a parameter in
+// RootElement. Each variant of ParamType is expected to have a Parse method
+// defined that will be dispatched on when we are attempting to parse a
+// parameter
+using ParamType = std::variant<std::monostate>;
+
 } // namespace rootsig
 } // namespace hlsl
 } // namespace llvm

>From fb4e85dbda846dcde6f0a472ff2cfa46c5eced7d Mon Sep 17 00:00:00 2001
From: Finn Plummer <finnplummer at microsoft.com>
Date: Thu, 27 Mar 2025 18:34:02 +0000
Subject: [PATCH 2/3] define ParseUInt and ParseRegister to plug into
 parameters

---
 .../clang/Basic/DiagnosticParseKinds.td       |   1 +
 .../clang/Parse/ParseHLSLRootSignature.h      |   8 +
 clang/lib/Parse/ParseHLSLRootSignature.cpp    | 108 +++++++++++--
 .../Parse/ParseHLSLRootSignatureTest.cpp      | 146 +++++++++++++++++-
 .../llvm/Frontend/HLSL/HLSLRootSignature.h    |  11 +-
 5 files changed, 260 insertions(+), 14 deletions(-)

diff --git a/clang/include/clang/Basic/DiagnosticParseKinds.td b/clang/include/clang/Basic/DiagnosticParseKinds.td
index f25c9c930cc61..ab12159ba5ae1 100644
--- a/clang/include/clang/Basic/DiagnosticParseKinds.td
+++ b/clang/include/clang/Basic/DiagnosticParseKinds.td
@@ -1835,5 +1835,6 @@ def err_hlsl_unexpected_end_of_params
     : Error<"expected %0 to denote end of parameters, or, another valid parameter of %1">;
 def err_hlsl_rootsig_repeat_param : Error<"specified the same parameter '%0' multiple times">;
 def err_hlsl_rootsig_missing_param : Error<"did not specify mandatory parameter '%0'">;
+def err_hlsl_number_literal_overflow : Error<"integer literal is too large to be represented as a 32-bit %select{signed |}0 integer type">;
 
 } // end of Parser diagnostics
diff --git a/clang/include/clang/Parse/ParseHLSLRootSignature.h b/clang/include/clang/Parse/ParseHLSLRootSignature.h
index b39267dbaaa28..02e99e83875db 100644
--- a/clang/include/clang/Parse/ParseHLSLRootSignature.h
+++ b/clang/include/clang/Parse/ParseHLSLRootSignature.h
@@ -101,6 +101,14 @@ class RootSignatureParser {
       llvm::SmallDenseMap<TokenKind, llvm::hlsl::rootsig::ParamType> &Params,
       llvm::SmallDenseSet<TokenKind> &Mandatory);
 
+  /// Parameter parse methods corresponding to a ParamType
+  bool parseUIntParam(uint32_t *X);
+  bool parseRegister(llvm::hlsl::rootsig::Register *Reg);
+
+  /// Use NumericLiteralParser to convert CurToken.NumSpelling into an unsigned
+  /// 32-bit integer stored in *X. Returns true on error.
+  bool handleUIntLiteral(uint32_t *X);
+
   /// Invoke the Lexer to consume a token and update CurToken with the result
   void consumeNextToken() { CurToken = Lexer.ConsumeToken(); }
 
diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp
index 8bb78def243fe..6c6f154d07daf 100644
--- a/clang/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -8,6 +8,8 @@
 
 #include "clang/Parse/ParseHLSLRootSignature.h"
 
+#include "clang/Lex/LiteralSupport.h"
+
 #include "llvm/Support/raw_ostream.h"
 
 using namespace llvm::hlsl::rootsig;
@@ -87,46 +89,73 @@ bool RootSignatureParser::parseDescriptorTableClause() {
           CurToken.Kind == TokenKind::kw_UAV ||
           CurToken.Kind == TokenKind::kw_Sampler) &&
          "Expects to only be invoked starting at given keyword");
+  TokenKind ParamKind = CurToken.Kind; // retain for diagnostics
 
   DescriptorTableClause Clause;
-  switch (CurToken.Kind) {
+  TokenKind ExpectedRegister;
+  switch (ParamKind) {
   default:
     llvm_unreachable("Switch for consumed token was not provided");
   case TokenKind::kw_CBV:
     Clause.Type = ClauseType::CBuffer;
+    ExpectedRegister = TokenKind::bReg;
     break;
   case TokenKind::kw_SRV:
     Clause.Type = ClauseType::SRV;
+    ExpectedRegister = TokenKind::tReg;
     break;
   case TokenKind::kw_UAV:
     Clause.Type = ClauseType::UAV;
+    ExpectedRegister = TokenKind::uReg;
     break;
   case TokenKind::kw_Sampler:
     Clause.Type = ClauseType::Sampler;
+    ExpectedRegister = TokenKind::sReg;
     break;
   }
 
   if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
-                           CurToken.Kind))
+                           ParamKind))
     return true;
 
-  if (consumeExpectedToken(TokenKind::pu_r_paren, diag::err_expected_after,
-                           CurToken.Kind))
+  llvm::SmallDenseMap<TokenKind, ParamType> Params = {
+      {ExpectedRegister, &Clause.Register},
+      {TokenKind::kw_space, &Clause.Space},
+  };
+  llvm::SmallDenseSet<TokenKind> Mandatory = {
+      ExpectedRegister,
+  };
+
+  if (parseParams(Params, Mandatory))
     return true;
 
+  if (!tryConsumeExpectedToken(TokenKind::pu_r_paren)) {
+    getDiags().Report(CurToken.TokLoc, diag::err_hlsl_unexpected_end_of_params)
+        << /*expected=*/TokenKind::pu_r_paren
+        << /*param of=*/ParamKind;
+    return true;
+  }
+
   Elements.push_back(Clause);
   return false;
 }
 
-// Helper struct so that we can use the overloaded notation of std::visit
+// Helper struct defined to use the overloaded notation of std::visit.
 template <class... Ts> struct ParseMethods : Ts... { using Ts::operator()...; };
 template <class... Ts> ParseMethods(Ts...) -> ParseMethods<Ts...>;
 
 bool RootSignatureParser::parseParam(ParamType Ref) {
-  bool Error = true;
-  std::visit(ParseMethods{}, Ref);
-
-  return Error;
+  return std::visit(
+      ParseMethods{
+          [this](Register *X) -> bool { return parseRegister(X); },
+          [this](uint32_t *X) -> bool {
+            return consumeExpectedToken(TokenKind::pu_equal,
+                                        diag::err_expected_after,
+                                        CurToken.Kind) ||
+                   parseUIntParam(X);
+          },
+      },
+      Ref);
 }
 
 bool RootSignatureParser::parseParams(
@@ -169,6 +198,67 @@ bool RootSignatureParser::parseParams(
   return !AllMandatoryDefined;
 }
 
+bool RootSignatureParser::parseUIntParam(uint32_t *X) {
+  assert(CurToken.Kind == TokenKind::pu_equal &&
+         "Expects to only be invoked starting at given keyword");
+  tryConsumeExpectedToken(TokenKind::pu_plus); // optional leading '+'
+  return consumeExpectedToken(TokenKind::int_literal, diag::err_expected_after,
+                              CurToken.Kind) ||
+         handleUIntLiteral(X);
+}
+
+bool RootSignatureParser::parseRegister(Register *Register) {
+  assert(
+      (CurToken.Kind == TokenKind::bReg || CurToken.Kind == TokenKind::tReg ||
+       CurToken.Kind == TokenKind::uReg || CurToken.Kind == TokenKind::sReg) &&
+      "Expects to only be invoked starting at given keyword");
+
+  switch (CurToken.Kind) {
+  case TokenKind::bReg:
+    Register->ViewType = RegisterType::BReg;
+    break;
+  case TokenKind::tReg:
+    Register->ViewType = RegisterType::TReg;
+    break;
+  case TokenKind::uReg:
+    Register->ViewType = RegisterType::UReg;
+    break;
+  case TokenKind::sReg:
+    Register->ViewType = RegisterType::SReg;
+    break;
+  default:
+    break; // Unreachable given Try + assert pattern
+  }
+
+  if (handleUIntLiteral(&Register->Number))
+    return true; // propagate NumericLiteralParser error
+
+  return false;
+}
+
+bool RootSignatureParser::handleUIntLiteral(uint32_t *X) {
+  // Parse the numeric value and do semantic checks on its specification
+  clang::NumericLiteralParser Literal(CurToken.NumSpelling, CurToken.TokLoc,
+                                      PP.getSourceManager(), PP.getLangOpts(),
+                                      PP.getTargetInfo(), PP.getDiagnostics());
+  if (Literal.hadError)
+    return true; // Error has already been reported so just return
+
+  assert(Literal.isIntegerLiteral() && "IsNumberChar will only support digits");
+
+  llvm::APSInt Val = llvm::APSInt(32, false);
+  if (Literal.GetIntegerValue(Val)) {
+    // Overflow. NOTE(review): `<< 0` selects "signed" yet 4294967294 parses
+    PP.getDiagnostics().Report(CurToken.TokLoc,
+                               diag::err_hlsl_number_literal_overflow)
+        << 0 << CurToken.NumSpelling;
+    return true;
+  }
+
+  *X = Val.getExtValue();
+  return false;
+}
+
 bool RootSignatureParser::peekExpectedToken(TokenKind Expected) {
   return peekExpectedToken(ArrayRef{Expected});
 }
diff --git a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
index acdf455a5d6aa..5b162e9a1b4cd 100644
--- a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
+++ b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
@@ -130,10 +130,10 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseEmptyTest) {
 TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
   const llvm::StringLiteral Source = R"cc(
     DescriptorTable(
-      CBV(),
-      SRV(),
-      Sampler(),
-      UAV()
+      CBV(b0),
+      SRV(space = 3, t42),
+      Sampler(s987, space = +2),
+      UAV(u4294967294)
     ),
     DescriptorTable()
   )cc";
@@ -155,18 +155,34 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
   RootElement Elem = Elements[0];
   ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::CBuffer);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.ViewType,
+            RegisterType::BReg);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.Number, 0u);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 0u);
 
   Elem = Elements[1];
   ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::SRV);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.ViewType,
+            RegisterType::TReg);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.Number, 42u);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 3u);
 
   Elem = Elements[2];
   ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::Sampler);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.ViewType,
+            RegisterType::SReg);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.Number, 987u);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 2u);
 
   Elem = Elements[3];
   ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::UAV);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.ViewType,
+            RegisterType::UReg);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.Number, 4294967294u);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 0u);
 
   Elem = Elements[4];
   ASSERT_TRUE(std::holds_alternative<DescriptorTable>(Elem));
@@ -176,6 +192,32 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
   Elem = Elements[5];
   ASSERT_TRUE(std::holds_alternative<DescriptorTable>(Elem));
   ASSERT_EQ(std::get<DescriptorTable>(Elem).NumClauses, 0u);
+
+  ASSERT_TRUE(Consumer->isSatisfied());
+}
+
+TEST_F(ParseHLSLRootSignatureTest, ValidTrailingCommaTest) {
+  // This test checks that we can handle trailing commas ','
+  const llvm::StringLiteral Source = R"cc(
+    DescriptorTable(
+      CBV(b0, ),
+      SRV(t42),
+    )
+  )cc";
+
+  TrivialModuleLoader ModLoader;
+  auto PP = createPP(Source, ModLoader);
+  auto TokLoc = SourceLocation();
+
+  hlsl::RootSignatureLexer Lexer(Source, TokLoc);
+  SmallVector<RootElement> Elements;
+  hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);
+
+  // Test no diagnostics produced
+  Consumer->setNoDiag();
+
+  ASSERT_FALSE(Parser.parse());
+
   ASSERT_TRUE(Consumer->isSatisfied());
 }
 
@@ -237,6 +279,102 @@ TEST_F(ParseHLSLRootSignatureTest, InvalidParseUnexpectedEndOfStreamTest) {
 
   // Test correct diagnostic produced - end of stream
   Consumer->setExpected(diag::err_expected_after);
+
+  ASSERT_TRUE(Parser.parse());
+
+  ASSERT_TRUE(Consumer->isSatisfied());
+}
+
+TEST_F(ParseHLSLRootSignatureTest, InvalidMissingParameterTest) {
+  // This test will check that the parsing fails due to a mandatory
+  // parameter (register) not being specified
+  const llvm::StringLiteral Source = R"cc(
+    DescriptorTable(
+      CBV()
+    )
+  )cc";
+
+  TrivialModuleLoader ModLoader;
+  auto PP = createPP(Source, ModLoader);
+  auto TokLoc = SourceLocation();
+
+  hlsl::RootSignatureLexer Lexer(Source, TokLoc);
+  SmallVector<RootElement> Elements;
+  hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);
+
+  // Test correct diagnostic produced
+  Consumer->setExpected(diag::err_hlsl_rootsig_missing_param);
+  ASSERT_TRUE(Parser.parse());
+
+  ASSERT_TRUE(Consumer->isSatisfied());
+}
+
+TEST_F(ParseHLSLRootSignatureTest, InvalidRepeatedMandatoryParameterTest) {
+  // This test will check that the parsing fails due to the same mandatory
+  // parameter being specified multiple times
+  const llvm::StringLiteral Source = R"cc(
+    DescriptorTable(
+      CBV(b32, b84)
+    )
+  )cc";
+
+  TrivialModuleLoader ModLoader;
+  auto PP = createPP(Source, ModLoader);
+  auto TokLoc = SourceLocation();
+
+  hlsl::RootSignatureLexer Lexer(Source, TokLoc);
+  SmallVector<RootElement> Elements;
+  hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);
+
+  // Test correct diagnostic produced
+  Consumer->setExpected(diag::err_hlsl_rootsig_repeat_param);
+  ASSERT_TRUE(Parser.parse());
+
+  ASSERT_TRUE(Consumer->isSatisfied());
+}
+
+TEST_F(ParseHLSLRootSignatureTest, InvalidRepeatedOptionalParameterTest) {
+  // This test will check that the parsing fails due to the same optional
+  // parameter being specified multiple times
+  const llvm::StringLiteral Source = R"cc(
+    DescriptorTable(
+      CBV(space = 2, space = 0)
+    )
+  )cc";
+
+  TrivialModuleLoader ModLoader;
+  auto PP = createPP(Source, ModLoader);
+  auto TokLoc = SourceLocation();
+
+  hlsl::RootSignatureLexer Lexer(Source, TokLoc);
+  SmallVector<RootElement> Elements;
+  hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);
+
+  // Test correct diagnostic produced
+  Consumer->setExpected(diag::err_hlsl_rootsig_repeat_param);
+  ASSERT_TRUE(Parser.parse());
+
+  ASSERT_TRUE(Consumer->isSatisfied());
+}
+
+TEST_F(ParseHLSLRootSignatureTest, InvalidLexOverflowedNumberTest) {
+  // This test will check that the parsing fails due to an integer overflow
+  const llvm::StringLiteral Source = R"cc(
+    DescriptorTable(
+      CBV(b4294967296)
+    )
+  )cc";
+
+  TrivialModuleLoader ModLoader;
+  auto PP = createPP(Source, ModLoader);
+  auto TokLoc = SourceLocation();
+
+  hlsl::RootSignatureLexer Lexer(Source, TokLoc);
+  SmallVector<RootElement> Elements;
+  hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);
+
+  // Test correct diagnostic produced
+  Consumer->setExpected(diag::err_hlsl_number_literal_overflow);
   ASSERT_TRUE(Parser.parse());
 
   ASSERT_TRUE(Consumer->isSatisfied());
diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
index 825beeea961cd..0ae8879b6c7d5 100644
--- a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
+++ b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
@@ -23,6 +23,13 @@ namespace rootsig {
 
 // Definitions of the in-memory data layout structures
 
+// Models the different registers: bReg | tReg | uReg | sReg
+enum class RegisterType { BReg, TReg, UReg, SReg };
+struct Register {
+  RegisterType ViewType;
+  uint32_t Number;
+};
+
 // Models the end of a descriptor table and stores its visibility
 struct DescriptorTable {
   uint32_t NumClauses = 0; // The number of clauses in the table
@@ -32,6 +39,8 @@ struct DescriptorTable {
 using ClauseType = llvm::dxil::ResourceClass;
 struct DescriptorTableClause {
   ClauseType Type;
+  Register Register;
+  uint32_t Space = 0;
 };
 
 // Models RootElement : DescriptorTable | DescriptorTableClause
@@ -41,7 +50,7 @@ using RootElement = std::variant<DescriptorTable, DescriptorTableClause>;
 // RootElement. Each variant of ParamType is expected to have a Parse method
 // defined that will be dispatched on when we are attempting to parse a
 // parameter
-using ParamType = std::variant<std::monostate>;
+using ParamType = std::variant<uint32_t *, Register *>;
 
 } // namespace rootsig
 } // namespace hlsl

>From 751e8a11a9efd63b24729994dcf66ee300d99ac3 Mon Sep 17 00:00:00 2001
From: Finn Plummer <finnplummer at microsoft.com>
Date: Mon, 31 Mar 2025 18:47:15 +0000
Subject: [PATCH 3/3] self-review: use consumeExpectedToken api to report
 unexpected_end_of_params diag

---
 clang/lib/Parse/ParseHLSLRootSignature.cpp | 26 +++++++++-------------
 1 file changed, 11 insertions(+), 15 deletions(-)

diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp
index 6c6f154d07daf..62d29baea49d3 100644
--- a/clang/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -41,12 +41,11 @@ bool RootSignatureParser::parse() {
       break;
   }
 
-  if (!tryConsumeExpectedToken(TokenKind::end_of_stream)) {
-    getDiags().Report(CurToken.TokLoc, diag::err_hlsl_unexpected_end_of_params)
-        << /*expected=*/TokenKind::end_of_stream
-        << /*param of=*/TokenKind::kw_RootSignature;
+  if (consumeExpectedToken(TokenKind::end_of_stream,
+                           diag::err_hlsl_unexpected_end_of_params,
+                           /*param of=*/TokenKind::kw_RootSignature))
     return true;
-  }
+
   return false;
 }
 
@@ -72,12 +71,10 @@ bool RootSignatureParser::parseDescriptorTable() {
       break;
   }
 
-  if (!tryConsumeExpectedToken(TokenKind::pu_r_paren)) {
-    getDiags().Report(CurToken.TokLoc, diag::err_hlsl_unexpected_end_of_params)
-        << /*expected=*/TokenKind::pu_r_paren
-        << /*param of=*/TokenKind::kw_DescriptorTable;
+  if (consumeExpectedToken(TokenKind::pu_r_paren,
+                           diag::err_hlsl_unexpected_end_of_params,
+                           /*param of=*/TokenKind::kw_DescriptorTable))
     return true;
-  }
 
   Elements.push_back(Table);
   return false;
@@ -129,12 +126,10 @@ bool RootSignatureParser::parseDescriptorTableClause() {
   if (parseParams(Params, Mandatory))
     return true;
 
-  if (!tryConsumeExpectedToken(TokenKind::pu_r_paren)) {
-    getDiags().Report(CurToken.TokLoc, diag::err_hlsl_unexpected_end_of_params)
-        << /*expected=*/TokenKind::pu_r_paren
-        << /*param of=*/ParamKind;
+  if (consumeExpectedToken(TokenKind::pu_r_paren,
+                           diag::err_hlsl_unexpected_end_of_params,
+                           /*param of=*/ParamKind))
     return true;
-  }
 
   Elements.push_back(Clause);
   return false;
@@ -280,6 +275,7 @@ bool RootSignatureParser::consumeExpectedToken(TokenKind Expected,
   case diag::err_expected:
     DB << Expected;
     break;
+  case diag::err_hlsl_unexpected_end_of_params:
   case diag::err_expected_either:
   case diag::err_expected_after:
     DB << Expected << Context;



More information about the llvm-branch-commits mailing list