[llvm-branch-commits] [clang] [llvm] [HLSL][RootSignature] Implement Parsing of Descriptor Tables (PR #122982)

Finn Plummer via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Tue Jan 28 13:13:11 PST 2025


https://github.com/inbelic updated https://github.com/llvm/llvm-project/pull/122982

>From 00731b48b819657509cf4633806fa9654c4457f2 Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Fri, 24 Jan 2025 22:23:39 +0000
Subject: [PATCH 01/18] [HLSL][RootSignature] Initial Lexer Definition with
 punctuators

- Defines the RootSignatureLexer class
- Defines the test harness required for testing

- Implements the punctuator tokens and tests their functionality
---
 .../include/clang/Basic/DiagnosticLexKinds.td |   6 +
 .../Parse/HLSLRootSignatureTokenKinds.def     |  35 ++++
 .../clang/Parse/ParseHLSLRootSignature.h      |  79 +++++++++
 clang/lib/Parse/CMakeLists.txt                |   1 +
 clang/lib/Parse/ParseHLSLRootSignature.cpp    |  50 ++++++
 clang/unittests/CMakeLists.txt                |   1 +
 clang/unittests/Parse/CMakeLists.txt          |  26 +++
 .../Parse/ParseHLSLRootSignatureTest.cpp      | 167 ++++++++++++++++++
 8 files changed, 365 insertions(+)
 create mode 100644 clang/include/clang/Parse/HLSLRootSignatureTokenKinds.def
 create mode 100644 clang/include/clang/Parse/ParseHLSLRootSignature.h
 create mode 100644 clang/lib/Parse/ParseHLSLRootSignature.cpp
 create mode 100644 clang/unittests/Parse/CMakeLists.txt
 create mode 100644 clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp

diff --git a/clang/include/clang/Basic/DiagnosticLexKinds.td b/clang/include/clang/Basic/DiagnosticLexKinds.td
index 959376b0847216..7755c05bc8969b 100644
--- a/clang/include/clang/Basic/DiagnosticLexKinds.td
+++ b/clang/include/clang/Basic/DiagnosticLexKinds.td
@@ -1017,4 +1017,10 @@ Error<"'#pragma unsafe_buffer_usage' was not ended">;
 
 def err_pp_pragma_unsafe_buffer_usage_syntax :
 Error<"expected 'begin' or 'end'">;
+
+// HLSL Root Signature Lexing Errors
+let CategoryName = "Root Signature Lexical Issue" in {
+  def err_hlsl_invalid_token: Error<"unable to lex a valid Root Signature token">;
+}
+
 }
diff --git a/clang/include/clang/Parse/HLSLRootSignatureTokenKinds.def b/clang/include/clang/Parse/HLSLRootSignatureTokenKinds.def
new file mode 100644
index 00000000000000..9625f6a5bd76d9
--- /dev/null
+++ b/clang/include/clang/Parse/HLSLRootSignatureTokenKinds.def
@@ -0,0 +1,35 @@
+//===--- HLSLRootSignatureTokenKinds.def - Tokens and Enums -----*- C++ -*-===//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the TokenKinds used in the Root Signature DSL. This
+// includes keywords, enums and a small subset of punctuators. Users of this
+// file may #define the TOK, PUNCTUATOR, KEYWORD, ENUM or specific ENUM macros
+// before including it; any macro left undefined gets a default definition.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TOK
+#define TOK(X)
+#endif
+#ifndef PUNCTUATOR
+#define PUNCTUATOR(X,Y) TOK(pu_ ## X)
+#endif
+
+// General Tokens:
+TOK(invalid)
+
+// Punctuators:
+PUNCTUATOR(l_paren, '(')
+PUNCTUATOR(r_paren, ')')
+PUNCTUATOR(comma,   ',')
+PUNCTUATOR(or,      '|')
+PUNCTUATOR(equal,   '=')
+
+#undef PUNCTUATOR
+#undef TOK
diff --git a/clang/include/clang/Parse/ParseHLSLRootSignature.h b/clang/include/clang/Parse/ParseHLSLRootSignature.h
new file mode 100644
index 00000000000000..39069e7cc39988
--- /dev/null
+++ b/clang/include/clang/Parse/ParseHLSLRootSignature.h
@@ -0,0 +1,79 @@
+//===--- ParseHLSLRootSignature.h -------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+//  This file defines the ParseHLSLRootSignature interface.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_PARSE_PARSEHLSLROOTSIGNATURE_H
+#define LLVM_CLANG_PARSE_PARSEHLSLROOTSIGNATURE_H
+
+#include "clang/Basic/DiagnosticLex.h"
+#include "clang/Lex/Preprocessor.h"
+
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+
+namespace clang {
+namespace hlsl {
+
+struct RootSignatureToken {
+  enum Kind {
+#define TOK(X) X,
+#include "clang/Parse/HLSLRootSignatureTokenKinds.def"
+  };
+
+  Kind Kind = Kind::invalid;
+
+  // Retain the SouceLocation of the token for diagnostics
+  clang::SourceLocation TokLoc;
+
+  // Constructors
+  RootSignatureToken(clang::SourceLocation TokLoc) : TokLoc(TokLoc) {}
+};
+using TokenKind = enum RootSignatureToken::Kind;
+
+class RootSignatureLexer {
+public:
+  RootSignatureLexer(StringRef Signature, clang::SourceLocation SourceLoc,
+                     clang::Preprocessor &PP)
+      : Buffer(Signature), SourceLoc(SourceLoc), PP(PP) {}
+
+  // Consumes the internal buffer as a list of tokens and will emplace them
+  // onto the given tokens.
+  //
+  // It will consume until it successfully reaches the end of the buffer,
+  // or, until the first error is encountered. The return value denotes if
+  // there was a failure.
+  bool Lex(SmallVector<RootSignatureToken> &Tokens);
+
+private:
+  // Internal buffer to iterate over
+  StringRef Buffer;
+
+  // Passed down parameters from Sema
+  clang::SourceLocation SourceLoc;
+  clang::Preprocessor &PP;
+
+  // Consumes the internal buffer for a single token.
+  //
+  // The return value denotes if there was a failure.
+  bool LexToken(RootSignatureToken &Token);
+
+  // Advance the buffer by the specified number of characters. Updates the
+  // SourceLocation appropriately.
+  void AdvanceBuffer(unsigned NumCharacters = 1) {
+    Buffer = Buffer.drop_front(NumCharacters);
+    SourceLoc = SourceLoc.getLocWithOffset(NumCharacters);
+  }
+};
+
+} // namespace hlsl
+} // namespace clang
+
+#endif // LLVM_CLANG_PARSE_PARSEHLSLROOTSIGNATURE_H
diff --git a/clang/lib/Parse/CMakeLists.txt b/clang/lib/Parse/CMakeLists.txt
index 22e902f7e1bc50..00fde537bb9c60 100644
--- a/clang/lib/Parse/CMakeLists.txt
+++ b/clang/lib/Parse/CMakeLists.txt
@@ -14,6 +14,7 @@ add_clang_library(clangParse
   ParseExpr.cpp
   ParseExprCXX.cpp
   ParseHLSL.cpp
+  ParseHLSLRootSignature.cpp
   ParseInit.cpp
   ParseObjc.cpp
   ParseOpenMP.cpp
diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp
new file mode 100644
index 00000000000000..a9a9d209085c91
--- /dev/null
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -0,0 +1,50 @@
+#include "clang/Parse/ParseHLSLRootSignature.h"
+
+namespace clang {
+namespace hlsl {
+
+// Lexer Definitions
+
+bool RootSignatureLexer::Lex(SmallVector<RootSignatureToken> &Tokens) {
+  // Discard any leading whitespace
+  AdvanceBuffer(Buffer.take_while(isspace).size());
+
+  while (!Buffer.empty()) {
+    // Record where this token is in the text for usage in parser diagnostics
+    RootSignatureToken Result(SourceLoc);
+    if (LexToken(Result))
+      return true;
+
+    // Successfully Lexed the token so we can store it
+    Tokens.push_back(Result);
+
+    // Discard any trailing whitespace
+    AdvanceBuffer(Buffer.take_while(isspace).size());
+  }
+
+  return false;
+}
+
+bool RootSignatureLexer::LexToken(RootSignatureToken &Result) {
+  char C = Buffer.front();
+
+  // Punctuators
+  switch (C) {
+#define PUNCTUATOR(X, Y)                                                       \
+  case Y: {                                                                    \
+    Result.Kind = TokenKind::pu_##X;                                           \
+    AdvanceBuffer();                                                           \
+    return false;                                                              \
+  }
+#include "clang/Parse/HLSLRootSignatureTokenKinds.def"
+  default:
+    break;
+  }
+
+  // Unable to match on any token type
+  PP.getDiagnostics().Report(Result.TokLoc, diag::err_hlsl_invalid_token);
+  return true;
+}
+
+} // namespace hlsl
+} // namespace clang
diff --git a/clang/unittests/CMakeLists.txt b/clang/unittests/CMakeLists.txt
index 85d265426ec80b..9b3ce8aa7de739 100644
--- a/clang/unittests/CMakeLists.txt
+++ b/clang/unittests/CMakeLists.txt
@@ -25,6 +25,7 @@ endfunction()
 
 add_subdirectory(Basic)
 add_subdirectory(Lex)
+add_subdirectory(Parse)
 add_subdirectory(Driver)
 if(CLANG_ENABLE_STATIC_ANALYZER)
   add_subdirectory(Analysis)
diff --git a/clang/unittests/Parse/CMakeLists.txt b/clang/unittests/Parse/CMakeLists.txt
new file mode 100644
index 00000000000000..1b7eb4934a46c8
--- /dev/null
+++ b/clang/unittests/Parse/CMakeLists.txt
@@ -0,0 +1,26 @@
+set(LLVM_LINK_COMPONENTS
+  Support
+  )
+
+add_clang_unittest(ParseTests
+  ParseHLSLRootSignatureTest.cpp
+  )
+
+clang_target_link_libraries(ParseTests
+  PRIVATE
+  clangAST
+  clangASTMatchers
+  clangBasic
+  clangFrontend
+  clangParse
+  clangSema
+  clangSerialization
+  clangTooling
+  )
+
+target_link_libraries(ParseTests
+  PRIVATE
+  LLVMTestingAnnotations
+  LLVMTestingSupport
+  clangTesting
+  )
diff --git a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
new file mode 100644
index 00000000000000..e5f88bbfa0ff6a
--- /dev/null
+++ b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
@@ -0,0 +1,167 @@
+//=== ParseHLSLRootSignatureTest.cpp - Parse Root Signature tests ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/Basic/Diagnostic.h"
+#include "clang/Basic/DiagnosticOptions.h"
+#include "clang/Basic/FileManager.h"
+#include "clang/Basic/LangOptions.h"
+#include "clang/Basic/SourceLocation.h"
+#include "clang/Basic/SourceManager.h"
+#include "clang/Basic/TargetInfo.h"
+#include "clang/Lex/HeaderSearch.h"
+#include "clang/Lex/HeaderSearchOptions.h"
+#include "clang/Lex/Lexer.h"
+#include "clang/Lex/ModuleLoader.h"
+#include "clang/Lex/Preprocessor.h"
+#include "clang/Lex/PreprocessorOptions.h"
+
+#include "clang/Parse/ParseHLSLRootSignature.h"
+#include "gtest/gtest.h"
+
+using namespace clang;
+
+namespace {
+
+// Diagnostic helper for helper tests
+class ExpectedDiagConsumer : public DiagnosticConsumer {
+  virtual void anchor() {}
+
+  void HandleDiagnostic(DiagnosticsEngine::Level DiagLevel,
+                        const Diagnostic &Info) override {
+    if (!FirstDiag || !ExpectedDiagID.has_value()) {
+      Satisfied = false;
+      return;
+    }
+    FirstDiag = false;
+
+    Satisfied = ExpectedDiagID.value() == Info.getID();
+  }
+
+  bool FirstDiag = true;
+  bool Satisfied = false;
+  std::optional<unsigned> ExpectedDiagID;
+
+public:
+  void SetNoDiag() {
+    Satisfied = true;
+    ExpectedDiagID = std::nullopt;
+  }
+
+  void SetExpected(unsigned DiagID) {
+    Satisfied = false;
+    ExpectedDiagID = DiagID;
+  }
+
+  bool IsSatisfied() { return Satisfied; }
+};
+
+// The test fixture.
+class ParseHLSLRootSignatureTest : public ::testing::Test {
+protected:
+  ParseHLSLRootSignatureTest()
+      : FileMgr(FileMgrOpts), DiagID(new DiagnosticIDs()),
+        Consumer(new ExpectedDiagConsumer()),
+        Diags(DiagID, new DiagnosticOptions, Consumer),
+        SourceMgr(Diags, FileMgr), TargetOpts(new TargetOptions) {
+    TargetOpts->Triple = "x86_64-apple-darwin11.1.0";
+    Target = TargetInfo::CreateTargetInfo(Diags, TargetOpts);
+  }
+
+  std::unique_ptr<Preprocessor> CreatePP(StringRef Source,
+                                         TrivialModuleLoader &ModLoader) {
+    std::unique_ptr<llvm::MemoryBuffer> Buf =
+        llvm::MemoryBuffer::getMemBuffer(Source);
+    SourceMgr.setMainFileID(SourceMgr.createFileID(std::move(Buf)));
+
+    HeaderSearch HeaderInfo(std::make_shared<HeaderSearchOptions>(), SourceMgr,
+                            Diags, LangOpts, Target.get());
+    std::unique_ptr<Preprocessor> PP = std::make_unique<Preprocessor>(
+        std::make_shared<PreprocessorOptions>(), Diags, LangOpts, SourceMgr,
+        HeaderInfo, ModLoader,
+        /*IILookup =*/nullptr,
+        /*OwnsHeaderSearch =*/false);
+    PP->Initialize(*Target);
+    PP->EnterMainSourceFile();
+    return PP;
+  }
+
+  void CheckTokens(SmallVector<hlsl::RootSignatureToken> &Computed,
+                   SmallVector<hlsl::TokenKind> &Expected) {
+    ASSERT_EQ(Computed.size(), Expected.size());
+    for (unsigned I = 0, E = Expected.size(); I != E; ++I) {
+      ASSERT_EQ(Computed[I].Kind, Expected[I]);
+    }
+  }
+
+  FileSystemOptions FileMgrOpts;
+  FileManager FileMgr;
+  IntrusiveRefCntPtr<DiagnosticIDs> DiagID;
+  ExpectedDiagConsumer *Consumer;
+  DiagnosticsEngine Diags;
+  SourceManager SourceMgr;
+  LangOptions LangOpts;
+  std::shared_ptr<TargetOptions> TargetOpts;
+  IntrusiveRefCntPtr<TargetInfo> Target;
+};
+
+// Valid Lexing Tests
+
+TEST_F(ParseHLSLRootSignatureTest, ValidLexAllTokensTest) {
+  // This test will check that we can lex all defined tokens as defined in
+  // HLSLRootSignatureTokenKinds.def, plus some additional integer variations
+  const llvm::StringLiteral Source = R"cc(
+    (),|=
+  )cc";
+
+  TrivialModuleLoader ModLoader;
+  auto PP = CreatePP(Source, ModLoader);
+  auto TokLoc = SourceLocation();
+
+  // Test no diagnostics produced
+  Consumer->SetNoDiag();
+
+  hlsl::RootSignatureLexer Lexer(Source, TokLoc, *PP);
+
+  SmallVector<hlsl::RootSignatureToken> Tokens = {
+      hlsl::RootSignatureToken(
+          SourceLocation()) // invalid token for completeness
+  };
+  ASSERT_FALSE(Lexer.Lex(Tokens));
+  ASSERT_TRUE(Consumer->IsSatisfied());
+
+  SmallVector<hlsl::TokenKind> Expected = {
+#define TOK(NAME) hlsl::TokenKind::NAME,
+#include "clang/Parse/HLSLRootSignatureTokenKinds.def"
+  };
+
+  CheckTokens(Tokens, Expected);
+}
+
+// Invalid Lexing Tests
+
+TEST_F(ParseHLSLRootSignatureTest, InvalidLexIdentifierTest) {
+  // This test will check that the lexing fails due to no valid token
+  const llvm::StringLiteral Source = R"cc(
+    notAToken
+  )cc";
+
+  TrivialModuleLoader ModLoader;
+  auto PP = CreatePP(Source, ModLoader);
+  auto TokLoc = SourceLocation();
+
+  // Test correct diagnostic produced
+  Consumer->SetExpected(diag::err_hlsl_invalid_token);
+
+  hlsl::RootSignatureLexer Lexer(Source, TokLoc, *PP);
+
+  SmallVector<hlsl::RootSignatureToken> Tokens;
+  ASSERT_TRUE(Lexer.Lex(Tokens));
+  ASSERT_TRUE(Consumer->IsSatisfied());
+}
+
+} // anonymous namespace

>From 1f0d54fa60ba8b72ca001dad8aa85ff05f40749a Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Fri, 24 Jan 2025 22:28:05 +0000
Subject: [PATCH 02/18] Add lexing of integer literals

- Integrate the use of the `NumericLiteralParser` to lex integer
literals
- Add additional HLSL-specific diagnostic messages
---
 .../include/clang/Basic/DiagnosticLexKinds.td |  4 ++
 .../Parse/HLSLRootSignatureTokenKinds.def     |  1 +
 .../clang/Parse/ParseHLSLRootSignature.h      |  6 ++
 clang/lib/Parse/ParseHLSLRootSignature.cpp    | 59 ++++++++++++++++
 .../Parse/ParseHLSLRootSignatureTest.cpp      | 69 +++++++++++++++++++
 5 files changed, 139 insertions(+)

diff --git a/clang/include/clang/Basic/DiagnosticLexKinds.td b/clang/include/clang/Basic/DiagnosticLexKinds.td
index 7755c05bc8969b..81d314c838cede 100644
--- a/clang/include/clang/Basic/DiagnosticLexKinds.td
+++ b/clang/include/clang/Basic/DiagnosticLexKinds.td
@@ -1020,6 +1020,10 @@ Error<"expected 'begin' or 'end'">;
 
 // HLSL Root Signature Lexing Errors
 let CategoryName = "Root Signature Lexical Issue" in {
+  def err_hlsl_invalid_number_literal:
+    Error<"expected number literal is not a supported number literal of unsigned integer or integer">;
+  def err_hlsl_number_literal_overflow :
+    Error<"provided %select{unsigned integer|signed integer}0 literal '%1' that overflows the maximum of 32 bits">;
   def err_hlsl_invalid_token: Error<"unable to lex a valid Root Signature token">;
 }
 
diff --git a/clang/include/clang/Parse/HLSLRootSignatureTokenKinds.def b/clang/include/clang/Parse/HLSLRootSignatureTokenKinds.def
index 9625f6a5bd76d9..64c5fd14a2017f 100644
--- a/clang/include/clang/Parse/HLSLRootSignatureTokenKinds.def
+++ b/clang/include/clang/Parse/HLSLRootSignatureTokenKinds.def
@@ -23,6 +23,7 @@
 
 // General Tokens:
 TOK(invalid)
+TOK(int_literal)
 
 // Punctuators:
 PUNCTUATOR(l_paren, '(')
diff --git a/clang/include/clang/Parse/ParseHLSLRootSignature.h b/clang/include/clang/Parse/ParseHLSLRootSignature.h
index 39069e7cc39988..3b2c61781b1da3 100644
--- a/clang/include/clang/Parse/ParseHLSLRootSignature.h
+++ b/clang/include/clang/Parse/ParseHLSLRootSignature.h
@@ -13,7 +13,9 @@
 #ifndef LLVM_CLANG_PARSE_PARSEHLSLROOTSIGNATURE_H
 #define LLVM_CLANG_PARSE_PARSEHLSLROOTSIGNATURE_H
 
+#include "clang/AST/APValue.h"
 #include "clang/Basic/DiagnosticLex.h"
+#include "clang/Lex/LiteralSupport.h"
 #include "clang/Lex/Preprocessor.h"
 
 #include "llvm/ADT/SmallVector.h"
@@ -33,6 +35,8 @@ struct RootSignatureToken {
   // Retain the SouceLocation of the token for diagnostics
   clang::SourceLocation TokLoc;
 
+  APValue NumLiteral = APValue();
+
   // Constructors
   RootSignatureToken(clang::SourceLocation TokLoc) : TokLoc(TokLoc) {}
 };
@@ -60,6 +64,8 @@ class RootSignatureLexer {
   clang::SourceLocation SourceLoc;
   clang::Preprocessor &PP;
 
+  bool LexNumber(RootSignatureToken &Result);
+
   // Consumes the internal buffer for a single token.
   //
   // The return value denotes if there was a failure.
diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp
index a9a9d209085c91..1cc75998ca07da 100644
--- a/clang/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -5,6 +5,61 @@ namespace hlsl {
 
 // Lexer Definitions
 
+static bool IsNumberChar(char C) {
+  // TODO(#120472): extend for float support exponents
+  return isdigit(C); // integer support
+}
+
+bool RootSignatureLexer::LexNumber(RootSignatureToken &Result) {
+  // NumericLiteralParser does not handle the sign so we will manually apply it
+  bool Negative = Buffer.front() == '-';
+  bool Signed = Negative || Buffer.front() == '+';
+  if (Signed)
+    AdvanceBuffer();
+
+  // Retrieve the possible number
+  StringRef NumSpelling = Buffer.take_while(IsNumberChar);
+
+  // Catch this now as the Literal Parser will accept it as valid
+  if (NumSpelling.empty()) {
+    PP.getDiagnostics().Report(Result.TokLoc,
+                               diag::err_hlsl_invalid_number_literal);
+    return true;
+  }
+
+  // Parse the numeric value and do semantic checks on its specification
+  clang::NumericLiteralParser Literal(NumSpelling, SourceLoc,
+                                      PP.getSourceManager(), PP.getLangOpts(),
+                                      PP.getTargetInfo(), PP.getDiagnostics());
+  if (Literal.hadError)
+    return true; // Error has already been reported so just return
+
+  if (!Literal.isIntegerLiteral()) {
+    // Note: if IsNumberChar allows for hexidecimal we will need to turn this
+    // into a diagnostics for potential fixed-point literals
+    llvm_unreachable("IsNumberChar will only support digits");
+    return true;
+  }
+
+  // Retrieve the number value to store into the token
+  Result.Kind = TokenKind::int_literal;
+
+  llvm::APSInt X = llvm::APSInt(32, !Signed);
+  if (Literal.GetIntegerValue(X)) {
+    // Report that the value has overflowed
+    PP.getDiagnostics().Report(Result.TokLoc,
+                               diag::err_hlsl_number_literal_overflow)
+        << (unsigned)Signed << NumSpelling;
+    return true;
+  }
+
+  X = Negative ? -X : X;
+  Result.NumLiteral = APValue(X);
+
+  AdvanceBuffer(NumSpelling.size());
+  return false;
+}
+
 bool RootSignatureLexer::Lex(SmallVector<RootSignatureToken> &Tokens) {
   // Discard any leading whitespace
   AdvanceBuffer(Buffer.take_while(isspace).size());
@@ -41,6 +96,10 @@ bool RootSignatureLexer::LexToken(RootSignatureToken &Result) {
     break;
   }
 
+  // Numeric constant
+  if (isdigit(C) || C == '-' || C == '+')
+    return LexNumber(Result);
+
   // Unable to match on any token type
   PP.getDiagnostics().Report(Result.TokLoc, diag::err_hlsl_invalid_token);
   return true;
diff --git a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
index e5f88bbfa0ff6a..713bc5b1257f74 100644
--- a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
+++ b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
@@ -111,10 +111,39 @@ class ParseHLSLRootSignatureTest : public ::testing::Test {
 
 // Valid Lexing Tests
 
+TEST_F(ParseHLSLRootSignatureTest, ValidLexNumbersTest) {
+  // This test will check that we can lex different number tokens
+  const llvm::StringLiteral Source = R"cc(
+    -42 42 +42
+  )cc";
+
+  TrivialModuleLoader ModLoader;
+  auto PP = CreatePP(Source, ModLoader);
+  auto TokLoc = SourceLocation();
+
+  // Test no diagnostics produced
+  Consumer->SetNoDiag();
+
+  hlsl::RootSignatureLexer Lexer(Source, TokLoc, *PP);
+
+  SmallVector<hlsl::RootSignatureToken> Tokens;
+  ASSERT_FALSE(Lexer.Lex(Tokens));
+  ASSERT_TRUE(Consumer->IsSatisfied());
+
+  SmallVector<hlsl::TokenKind> Expected = {
+      hlsl::TokenKind::int_literal,
+      hlsl::TokenKind::int_literal,
+      hlsl::TokenKind::int_literal,
+  };
+  CheckTokens(Tokens, Expected);
+}
+
 TEST_F(ParseHLSLRootSignatureTest, ValidLexAllTokensTest) {
   // This test will check that we can lex all defined tokens as defined in
   // HLSLRootSignatureTokenKinds.def, plus some additional integer variations
   const llvm::StringLiteral Source = R"cc(
+    42
+
     (),|=
   )cc";
 
@@ -144,6 +173,46 @@ TEST_F(ParseHLSLRootSignatureTest, ValidLexAllTokensTest) {
 
 // Invalid Lexing Tests
 
+TEST_F(ParseHLSLRootSignatureTest, InvalidLexOverflowedNumberTest) {
+  // This test will check that the lexing fails due to an integer overflow
+  const llvm::StringLiteral Source = R"cc(
+    4294967296
+  )cc";
+
+  TrivialModuleLoader ModLoader;
+  auto PP = CreatePP(Source, ModLoader);
+  auto TokLoc = SourceLocation();
+
+  // Test correct diagnostic produced
+  Consumer->SetExpected(diag::err_hlsl_number_literal_overflow);
+
+  hlsl::RootSignatureLexer Lexer(Source, TokLoc, *PP);
+
+  SmallVector<hlsl::RootSignatureToken> Tokens;
+  ASSERT_TRUE(Lexer.Lex(Tokens));
+  ASSERT_TRUE(Consumer->IsSatisfied());
+}
+
+TEST_F(ParseHLSLRootSignatureTest, InvalidLexEmptyNumberTest) {
+  // This test will check that the lexing fails due to no integer being provided
+  const llvm::StringLiteral Source = R"cc(
+    -
+  )cc";
+
+  TrivialModuleLoader ModLoader;
+  auto PP = CreatePP(Source, ModLoader);
+  auto TokLoc = SourceLocation();
+
+  // Test correct diagnostic produced
+  Consumer->SetExpected(diag::err_hlsl_invalid_number_literal);
+
+  hlsl::RootSignatureLexer Lexer(Source, TokLoc, *PP);
+
+  SmallVector<hlsl::RootSignatureToken> Tokens;
+  ASSERT_TRUE(Lexer.Lex(Tokens));
+  ASSERT_TRUE(Consumer->IsSatisfied());
+}
+
 TEST_F(ParseHLSLRootSignatureTest, InvalidLexIdentifierTest) {
   // This test will check that the lexing fails due to no valid token
   const llvm::StringLiteral Source = R"cc(

>From e80401b980b9d3a3163e0338578fbddc0b7fbe87 Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Fri, 24 Jan 2025 21:04:36 +0000
Subject: [PATCH 03/18] Add support for lexing registers

---
 .../include/clang/Basic/DiagnosticLexKinds.td |  1 +
 .../Parse/HLSLRootSignatureTokenKinds.def     |  6 +++
 clang/lib/Parse/ParseHLSLRootSignature.cpp    | 47 +++++++++++++++++++
 .../Parse/ParseHLSLRootSignatureTest.cpp      | 23 +++++++++
 4 files changed, 77 insertions(+)

diff --git a/clang/include/clang/Basic/DiagnosticLexKinds.td b/clang/include/clang/Basic/DiagnosticLexKinds.td
index 81d314c838cede..cbadd16df3f9db 100644
--- a/clang/include/clang/Basic/DiagnosticLexKinds.td
+++ b/clang/include/clang/Basic/DiagnosticLexKinds.td
@@ -1025,6 +1025,7 @@ let CategoryName = "Root Signature Lexical Issue" in {
   def err_hlsl_number_literal_overflow :
     Error<"provided %select{unsigned integer|signed integer}0 literal '%1' that overflows the maximum of 32 bits">;
   def err_hlsl_invalid_token: Error<"unable to lex a valid Root Signature token">;
+  def err_hlsl_invalid_register_literal: Error<"expected unsigned integer literal as part of register">;
 }
 
 }
diff --git a/clang/include/clang/Parse/HLSLRootSignatureTokenKinds.def b/clang/include/clang/Parse/HLSLRootSignatureTokenKinds.def
index 64c5fd14a2017f..fc4dbfef728798 100644
--- a/clang/include/clang/Parse/HLSLRootSignatureTokenKinds.def
+++ b/clang/include/clang/Parse/HLSLRootSignatureTokenKinds.def
@@ -25,6 +25,12 @@
 TOK(invalid)
 TOK(int_literal)
 
+// Register Tokens:
+TOK(bReg)
+TOK(tReg)
+TOK(uReg)
+TOK(sReg)
+
 // Punctuators:
 PUNCTUATOR(l_paren, '(')
 PUNCTUATOR(r_paren, ')')
diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp
index 1cc75998ca07da..1b9973beb44170 100644
--- a/clang/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -100,6 +100,53 @@ bool RootSignatureLexer::LexToken(RootSignatureToken &Result) {
   if (isdigit(C) || C == '-' || C == '+')
     return LexNumber(Result);
 
+  // All following tokens require at least one additional character
+  if (Buffer.size() <= 1) {
+    PP.getDiagnostics().Report(Result.TokLoc, diag::err_hlsl_invalid_token);
+    return true;
+  }
+
+  // Peek at the next character to deteremine token type
+  char NextC = Buffer[1];
+
+  // Registers: [tsub][0-9+]
+  if ((C == 't' || C == 's' || C == 'u' || C == 'b') && isdigit(NextC)) {
+    AdvanceBuffer();
+
+    if (LexNumber(Result))
+      return true; // Error parsing number which is already reported
+
+    // Lex number could also parse a float so ensure it was an unsigned int
+    if (Result.Kind != TokenKind::int_literal ||
+        Result.NumLiteral.getInt().isSigned()) {
+      // Return invalid number literal for register error
+      PP.getDiagnostics().Report(Result.TokLoc,
+                                 diag::err_hlsl_invalid_register_literal);
+      return true;
+    }
+
+    // Convert character to the register type.
+    // This is done after LexNumber to override the TokenKind
+    switch (C) {
+    case 'b':
+      Result.Kind = TokenKind::bReg;
+      break;
+    case 't':
+      Result.Kind = TokenKind::tReg;
+      break;
+    case 'u':
+      Result.Kind = TokenKind::uReg;
+      break;
+    case 's':
+      Result.Kind = TokenKind::sReg;
+      break;
+    default:
+      llvm_unreachable("Switch for an expected token was not provided");
+      return true;
+    }
+    return false;
+  }
+
   // Unable to match on any token type
   PP.getDiagnostics().Report(Result.TokLoc, diag::err_hlsl_invalid_token);
   return true;
diff --git a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
index 713bc5b1257f74..47195099ed60df 100644
--- a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
+++ b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
@@ -144,6 +144,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidLexAllTokensTest) {
   const llvm::StringLiteral Source = R"cc(
     42
 
+    b0 t43 u987 s234
+
     (),|=
   )cc";
 
@@ -213,6 +215,27 @@ TEST_F(ParseHLSLRootSignatureTest, InvalidLexEmptyNumberTest) {
   ASSERT_TRUE(Consumer->IsSatisfied());
 }
 
+TEST_F(ParseHLSLRootSignatureTest, InvalidLexRegNumberTest) {
+  // This test will check that the lexing fails due to no integer being provided
+  const llvm::StringLiteral Source = R"cc(
+    b32.4
+  )cc";
+
+  TrivialModuleLoader ModLoader;
+  auto PP = CreatePP(Source, ModLoader);
+  auto TokLoc = SourceLocation();
+
+  // Test correct diagnostic produced
+  Consumer->SetExpected(diag::err_hlsl_invalid_register_literal);
+
+  hlsl::RootSignatureLexer Lexer(Source, TokLoc, *PP);
+
+  SmallVector<hlsl::RootSignatureToken> Tokens;
+  ASSERT_TRUE(Lexer.Lex(Tokens));
+  // FIXME(#120472): This should be TRUE once we can lex a floating
+  ASSERT_FALSE(Consumer->IsSatisfied());
+}
+
 TEST_F(ParseHLSLRootSignatureTest, InvalidLexIdentifierTest) {
   // This test will check that the lexing fails due to no valid token
   const llvm::StringLiteral Source = R"cc(

>From 43ff4461ab0ce5e8c7e722a8b9c21f214f6b4bfd Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Fri, 24 Jan 2025 21:40:23 +0000
Subject: [PATCH 04/18] Add lexing for an example keyword and enum

---
 .../Parse/HLSLRootSignatureTokenKinds.def     | 20 ++++++++++++++++
 .../clang/Parse/ParseHLSLRootSignature.h      |  1 +
 clang/lib/Parse/ParseHLSLRootSignature.cpp    | 23 ++++++++++++++++---
 .../Parse/ParseHLSLRootSignatureTest.cpp      |  4 ++++
 4 files changed, 45 insertions(+), 3 deletions(-)

diff --git a/clang/include/clang/Parse/HLSLRootSignatureTokenKinds.def b/clang/include/clang/Parse/HLSLRootSignatureTokenKinds.def
index fc4dbfef728798..d73c9adbb94e5c 100644
--- a/clang/include/clang/Parse/HLSLRootSignatureTokenKinds.def
+++ b/clang/include/clang/Parse/HLSLRootSignatureTokenKinds.def
@@ -20,6 +20,17 @@
 #ifndef PUNCTUATOR
 #define PUNCTUATOR(X,Y) TOK(pu_ ## X)
 #endif
+#ifndef KEYWORD
+#define KEYWORD(X) TOK(kw_ ## X)
+#endif
+#ifndef ENUM
+#define ENUM(NAME, LIT) TOK(en_ ## NAME)
+#endif
+
+// Defines the various types of enum
+#ifndef DESCRIPTOR_RANGE_OFFSET_ENUM
+#define DESCRIPTOR_RANGE_OFFSET_ENUM(NAME, LIT) ENUM(NAME, LIT)
+#endif
 
 // General Tokens:
 TOK(invalid)
@@ -38,5 +49,14 @@ PUNCTUATOR(comma,   ',')
 PUNCTUATOR(or,      '|')
 PUNCTUATOR(equal,   '=')
 
+// RootElement Keywords:
+KEYWORD(DescriptorTable)
+
+// Descriptor Range Offset Enum:
+DESCRIPTOR_RANGE_OFFSET_ENUM(DescriptorRangeOffsetAppend, "DESCRIPTOR_RANGE_OFFSET_APPEND")
+
+#undef DESCRIPTOR_RANGE_OFFSET_ENUM
+#undef ENUM
+#undef KEYWORD
 #undef PUNCTUATOR
 #undef TOK
diff --git a/clang/include/clang/Parse/ParseHLSLRootSignature.h b/clang/include/clang/Parse/ParseHLSLRootSignature.h
index 3b2c61781b1da3..899608bd1527ea 100644
--- a/clang/include/clang/Parse/ParseHLSLRootSignature.h
+++ b/clang/include/clang/Parse/ParseHLSLRootSignature.h
@@ -20,6 +20,7 @@
 
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/StringSwitch.h"
 
 namespace clang {
 namespace hlsl {
diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp
index 1b9973beb44170..7ceb85a47a088e 100644
--- a/clang/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -147,9 +147,26 @@ bool RootSignatureLexer::LexToken(RootSignatureToken &Result) {
     return false;
   }
 
-  // Unable to match on any token type
-  PP.getDiagnostics().Report(Result.TokLoc, diag::err_hlsl_invalid_token);
-  return true;
+  // Keywords and Enums:
+  StringRef TokSpelling =
+      Buffer.take_while([](char C) { return isalnum(C) || C == '_'; });
+
+  // Define a large string switch statement for all the keywords and enums
+  auto Switch = llvm::StringSwitch<TokenKind>(TokSpelling);
+#define KEYWORD(NAME) Switch.Case(#NAME, TokenKind::kw_##NAME);
+#define ENUM(NAME, LIT) Switch.CaseLower(LIT, TokenKind::en_##NAME);
+#include "clang/Parse/HLSLRootSignatureTokenKinds.def"
+
+  // Then attempt to retreive a string from it
+  auto Kind = Switch.Default(TokenKind::invalid);
+  if (Kind == TokenKind::invalid) {
+    PP.getDiagnostics().Report(Result.TokLoc, diag::err_hlsl_invalid_token);
+    return true;
+  }
+
+  Result.Kind = Kind;
+  AdvanceBuffer(TokSpelling.size());
+  return false;
 }
 
 } // namespace hlsl
diff --git a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
index 47195099ed60df..d80ded10ba313c 100644
--- a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
+++ b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
@@ -147,6 +147,10 @@ TEST_F(ParseHLSLRootSignatureTest, ValidLexAllTokensTest) {
     b0 t43 u987 s234
 
     (),|=
+
+    DescriptorTable
+
+    DESCRIPTOR_RANGE_OFFSET_APPEND
   )cc";
 
   TrivialModuleLoader ModLoader;

>From 392d5b006656d37713fac15cc87b9b790b3c93d4 Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Fri, 24 Jan 2025 21:44:24 +0000
Subject: [PATCH 05/18] Add lexing for remaining DescriptorTable keywords and
 enums

---
 .../Parse/HLSLRootSignatureTokenKinds.def     | 59 +++++++++++++++++++
 .../Parse/ParseHLSLRootSignatureTest.cpp      | 19 ++++++
 2 files changed, 78 insertions(+)

diff --git a/clang/include/clang/Parse/HLSLRootSignatureTokenKinds.def b/clang/include/clang/Parse/HLSLRootSignatureTokenKinds.def
index d73c9adbb94e5c..5e47af8d4b1364 100644
--- a/clang/include/clang/Parse/HLSLRootSignatureTokenKinds.def
+++ b/clang/include/clang/Parse/HLSLRootSignatureTokenKinds.def
@@ -31,6 +31,23 @@
 #ifndef DESCRIPTOR_RANGE_OFFSET_ENUM
 #define DESCRIPTOR_RANGE_OFFSET_ENUM(NAME, LIT) ENUM(NAME, LIT)
 #endif
+#ifndef ROOT_DESCRIPTOR_FLAG_ENUM
+#define ROOT_DESCRIPTOR_FLAG_ENUM(NAME, LIT) ENUM(NAME, LIT)
+#endif
+// Note: ON denotes that the flag is unique from the above Root Descriptor
+//  Flags. This is required to avoid token kind enum conflicts.
+#ifndef DESCRIPTOR_RANGE_FLAG_ENUM_OFF
+#define DESCRIPTOR_RANGE_FLAG_ENUM_OFF(NAME, LIT)
+#endif
+#ifndef DESCRIPTOR_RANGE_FLAG_ENUM_ON
+#define DESCRIPTOR_RANGE_FLAG_ENUM_ON(NAME, LIT) ENUM(NAME, LIT)
+#endif
+#ifndef DESCRIPTOR_RANGE_FLAG_ENUM
+#define DESCRIPTOR_RANGE_FLAG_ENUM(NAME, LIT, ON) DESCRIPTOR_RANGE_FLAG_ENUM_##ON(NAME, LIT)
+#endif
+#ifndef SHADER_VISIBILITY_ENUM
+#define SHADER_VISIBILITY_ENUM(NAME, LIT) ENUM(NAME, LIT)
+#endif
 
 // General Tokens:
 TOK(invalid)
@@ -52,9 +69,51 @@ PUNCTUATOR(equal,   '=')
 // RootElement Keywords:
 KEYWORD(DescriptorTable)
 
+// DescriptorTable Keywords:
+KEYWORD(CBV)
+KEYWORD(SRV)
+KEYWORD(UAV)
+KEYWORD(Sampler)
+
+// General Parameter Keywords:
+KEYWORD(space)
+KEYWORD(visibility)
+KEYWORD(flags)
+
+// View Parameter Keywords:
+KEYWORD(numDescriptors)
+KEYWORD(offset)
+
 // Descriptor Range Offset Enum:
 DESCRIPTOR_RANGE_OFFSET_ENUM(DescriptorRangeOffsetAppend, "DESCRIPTOR_RANGE_OFFSET_APPEND")
 
+// Root Descriptor Flag Enums:
+ROOT_DESCRIPTOR_FLAG_ENUM(DataVolatile, "DATA_VOLATILE")
+ROOT_DESCRIPTOR_FLAG_ENUM(DataStaticWhileSetAtExecute, "DATA_STATIC_WHILE_SET_AT_EXECUTE")
+ROOT_DESCRIPTOR_FLAG_ENUM(DataStatic, "DATA_STATIC")
+
+// Descriptor Range Flag Enums:
+DESCRIPTOR_RANGE_FLAG_ENUM(DescriptorsVolatile, "DESCRIPTORS_VOLATILE", ON)
+DESCRIPTOR_RANGE_FLAG_ENUM(DataVolatile, "DATA_VOLATILE", OFF)
+DESCRIPTOR_RANGE_FLAG_ENUM(DataStaticWhileSetAtExecute, "DATA_STATIC_WHILE_SET_AT_EXECUTE", OFF)
+DESCRIPTOR_RANGE_FLAG_ENUM(DataStatic, "DATA_STATIC", OFF)
+DESCRIPTOR_RANGE_FLAG_ENUM(DescriptorsStaticKeepingBufferBoundsChecks, "DESCRIPTORS_STATIC_KEEPING_BUFFER_BOUNDS_CHECKS", ON)
+
+// Shader Visibiliy Enums:
+SHADER_VISIBILITY_ENUM(All, "SHADER_VISIBILITY_ALL")
+SHADER_VISIBILITY_ENUM(Vertex, "SHADER_VISIBILITY_VERTEX")
+SHADER_VISIBILITY_ENUM(Hull, "SHADER_VISIBILITY_HULL")
+SHADER_VISIBILITY_ENUM(Domain, "SHADER_VISIBILITY_DOMAIN")
+SHADER_VISIBILITY_ENUM(Geometry, "SHADER_VISIBILITY_GEOMETRY")
+SHADER_VISIBILITY_ENUM(Pixel, "SHADER_VISIBILITY_PIXEL")
+SHADER_VISIBILITY_ENUM(Amplification, "SHADER_VISIBILITY_AMPLIFICATION")
+SHADER_VISIBILITY_ENUM(Mesh, "SHADER_VISIBILITY_MESH")
+
+#undef SHADER_VISIBILITY_ENUM
+#undef DESCRIPTOR_RANGE_FLAG_ENUM
+#undef DESCRIPTOR_RANGE_FLAG_ENUM_OFF
+#undef DESCRIPTOR_RANGE_FLAG_ENUM_ON
+#undef ROOT_DESCRIPTOR_FLAG_ENUM
 #undef DESCRIPTOR_RANGE_OFFSET_ENUM
 #undef ENUM
 #undef KEYWORD
diff --git a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
index d80ded10ba313c..57b61e43746a0b 100644
--- a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
+++ b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
@@ -150,7 +150,26 @@ TEST_F(ParseHLSLRootSignatureTest, ValidLexAllTokensTest) {
 
     DescriptorTable
 
+    CBV SRV UAV Sampler
+    space visibility flags
+    numDescriptors offset
+
     DESCRIPTOR_RANGE_OFFSET_APPEND
+
+    DATA_VOLATILE
+    DATA_STATIC_WHILE_SET_AT_EXECUTE
+    DATA_STATIC
+    DESCRIPTORS_VOLATILE
+    DESCRIPTORS_STATIC_KEEPING_BUFFER_BOUNDS_CHECKS
+
+    shader_visibility_all
+    shader_visibility_vertex
+    shader_visibility_hull
+    shader_visibility_domain
+    shader_visibility_geometry
+    shader_visibility_pixel
+    shader_visibility_amplification
+    shader_visibility_mesh
   )cc";
 
   TrivialModuleLoader ModLoader;

>From 650d65bc3425963908d0d3c17699b13e6c5d94c7 Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Tue, 28 Jan 2025 17:49:21 +0000
Subject: [PATCH 06/18] [HLSL][RootSignature] Handle an empty root signature

- Define the Parser struct
- Model RootElements as a variant of the different types
- Create a basic test case for unit testing
---
 .../clang/Parse/ParseHLSLRootSignature.h      | 26 ++++++++++++++++
 clang/lib/Parse/ParseHLSLRootSignature.cpp    | 19 ++++++++++++
 .../Parse/ParseHLSLRootSignatureTest.cpp      | 26 ++++++++++++++++
 .../llvm/Frontend/HLSL/HLSLRootSignature.h    | 30 +++++++++++++++++++
 4 files changed, 101 insertions(+)
 create mode 100644 llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h

diff --git a/clang/include/clang/Parse/ParseHLSLRootSignature.h b/clang/include/clang/Parse/ParseHLSLRootSignature.h
index 899608bd1527ea..2e01f86f832f76 100644
--- a/clang/include/clang/Parse/ParseHLSLRootSignature.h
+++ b/clang/include/clang/Parse/ParseHLSLRootSignature.h
@@ -22,9 +22,13 @@
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/StringSwitch.h"
 
+#include "llvm/Frontend/HLSL/HLSLRootSignature.h"
+
 namespace clang {
 namespace hlsl {
 
+namespace rs = llvm::hlsl::root_signature;
+
 struct RootSignatureToken {
   enum Kind {
 #define TOK(X) X,
@@ -80,6 +84,28 @@ class RootSignatureLexer {
   }
 };
 
+class RootSignatureParser {
+public:
+  RootSignatureParser(SmallVector<rs::RootElement> &Elements,
+                      const SmallVector<RootSignatureToken> &Tokens,
+                      DiagnosticsEngine &Diags);
+
+  // Iterates over the provided tokens and constructs the in-memory
+  // representations of the RootElements.
+  //
+  // The return value denotes if there was a failure and the method will
+  // return on the first encountered failure, or, return false if it
+  // can sucessfully reach the end of the tokens.
+  bool Parse();
+
+private:
+  SmallVector<rs::RootElement> &Elements;
+  SmallVector<RootSignatureToken>::const_iterator CurTok;
+  SmallVector<RootSignatureToken>::const_iterator LastTok;
+
+  DiagnosticsEngine &Diags;
+};
+
 } // namespace hlsl
 } // namespace clang
 
diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp
index 7ceb85a47a088e..2ba756f3bd09f2 100644
--- a/clang/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -1,5 +1,7 @@
 #include "clang/Parse/ParseHLSLRootSignature.h"
 
+using namespace llvm::hlsl::root_signature;
+
 namespace clang {
 namespace hlsl {
 
@@ -169,5 +171,22 @@ bool RootSignatureLexer::LexToken(RootSignatureToken &Result) {
   return false;
 }
 
+// Parser Definitions
+
+RootSignatureParser::RootSignatureParser(
+    SmallVector<RootElement> &Elements,
+    const SmallVector<RootSignatureToken> &Tokens, DiagnosticsEngine &Diags)
+    : Elements(Elements), Diags(Diags) {
+  CurTok = Tokens.begin();
+  LastTok = Tokens.end();
+}
+
+bool RootSignatureParser::Parse() {
+  // Handle edge-case of empty RootSignature()
+  if (CurTok == LastTok)
+    return false;
+
+  return true;
+}
 } // namespace hlsl
 } // namespace clang
diff --git a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
index 57b61e43746a0b..87bceeeb3283ef 100644
--- a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
+++ b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
@@ -24,6 +24,7 @@
 #include "gtest/gtest.h"
 
 using namespace clang;
+using namespace llvm::hlsl::root_signature;
 
 namespace {
 
@@ -279,4 +280,29 @@ TEST_F(ParseHLSLRootSignatureTest, InvalidLexIdentifierTest) {
   ASSERT_TRUE(Consumer->IsSatisfied());
 }
 
+// Valid Parser Tests
+
+TEST_F(ParseHLSLRootSignatureTest, ValidParseEmptyTest) {
+  const llvm::StringLiteral Source = R"cc()cc";
+
+  TrivialModuleLoader ModLoader;
+  auto PP = CreatePP(Source, ModLoader);
+  auto TokLoc = SourceLocation();
+
+  // Test no diagnostics produced
+  Consumer->SetNoDiag();
+  hlsl::RootSignatureLexer Lexer(Source, TokLoc, *PP);
+
+  SmallVector<hlsl::RootSignatureToken> Tokens;
+  ASSERT_FALSE(Lexer.Lex(Tokens));
+
+  SmallVector<RootElement> Elements;
+  hlsl::RootSignatureParser Parser(Elements, Tokens, Diags);
+
+  ASSERT_FALSE(Parser.Parse());
+  ASSERT_EQ((int)Elements.size(), 0);
+
+  ASSERT_TRUE(Consumer->IsSatisfied());
+}
+
 } // anonymous namespace
diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
new file mode 100644
index 00000000000000..4c196d29a01bbb
--- /dev/null
+++ b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
@@ -0,0 +1,30 @@
+//===- HLSLRootSignature.h - HLSL Root Signature helper objects -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file This file contains helper objects for working with HLSL Root
+/// Signatures.
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_FRONTEND_HLSL_HLSLROOTSIGNATURE_H
+#define LLVM_FRONTEND_HLSL_HLSLROOTSIGNATURE_H
+
+#include <variant>
+
+namespace llvm {
+namespace hlsl {
+namespace root_signature {
+
+// Models RootElement
+using RootElement = std::variant<std::monostate>;
+
+} // namespace root_signature
+} // namespace hlsl
+} // namespace llvm
+
+#endif // LLVM_FRONTEND_HLSL_HLSLROOTSIGNATURE_H

>From 03c5dfe2fccee5a98a94136c4fb6b94d2e70615b Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Tue, 28 Jan 2025 18:04:00 +0000
Subject: [PATCH 07/18] add support for an empty descriptor table

---
 .../clang/Parse/ParseHLSLRootSignature.h      | 21 +++++
 clang/lib/Parse/ParseHLSLRootSignature.cpp    | 84 +++++++++++++++++++
 .../Parse/ParseHLSLRootSignatureTest.cpp      | 30 +++++++
 .../llvm/Frontend/HLSL/HLSLRootSignature.h    | 11 ++-
 4 files changed, 144 insertions(+), 2 deletions(-)

diff --git a/clang/include/clang/Parse/ParseHLSLRootSignature.h b/clang/include/clang/Parse/ParseHLSLRootSignature.h
index 2e01f86f832f76..43b5535dbdbffd 100644
--- a/clang/include/clang/Parse/ParseHLSLRootSignature.h
+++ b/clang/include/clang/Parse/ParseHLSLRootSignature.h
@@ -98,6 +98,27 @@ class RootSignatureParser {
   // can sucessfully reach the end of the tokens.
   bool Parse();
 
+private:
+  // Root Element helpers
+  bool ParseRootElement(bool First);
+  bool ParseDescriptorTable();
+
+  // Increment the token iterator if we have not reached the end.
+  // Return value denotes if we were already at the last token.
+  bool ConsumeNextToken();
+
+  // Is the current token one of the expected kinds
+  bool EnsureExpectedToken(TokenKind AnyExpected);
+  bool EnsureExpectedToken(ArrayRef<TokenKind> AnyExpected);
+
+  // Consume the next token and report an error if it is not of the expected
+  // kind.
+  //
+  // Return value denotes if it failed to match the expected kind, either it is
+  // the end of the stream or it didn't match any of the expected kinds.
+  bool ConsumeExpectedToken(TokenKind Expected);
+  bool ConsumeExpectedToken(ArrayRef<TokenKind> AnyExpected);
+
 private:
   SmallVector<rs::RootElement> &Elements;
   SmallVector<RootSignatureToken>::const_iterator CurTok;
diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp
index 2ba756f3bd09f2..ff66763be9af3a 100644
--- a/clang/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -186,7 +186,91 @@ bool RootSignatureParser::Parse() {
   if (CurTok == LastTok)
     return false;
 
+  bool First = true;
+  // Iterate as many RootElements as possible
+  while (!ParseRootElement(First)) {
+    First = false;
+    // Avoid use of ConsumeNextToken here to skip incorrect end of tokens error
+    CurTok++;
+    if (CurTok == LastTok)
+      return false;
+    if (EnsureExpectedToken(TokenKind::pu_comma))
+      return true;
+  }
+
+  return true;
+}
+
+bool RootSignatureParser::ParseRootElement(bool First) {
+  if (First && EnsureExpectedToken(TokenKind::kw_DescriptorTable))
+    return true;
+  if (!First && ConsumeExpectedToken(TokenKind::kw_DescriptorTable))
+    return true;
+
+  // Dispatch onto the correct parse method
+  switch (CurTok->Kind) {
+  case TokenKind::kw_DescriptorTable:
+    return ParseDescriptorTable();
+  default:
+    llvm_unreachable("Switch for an expected token was not provided");
+    return true;
+  }
+}
+
+bool RootSignatureParser::ParseDescriptorTable() {
+  DescriptorTable Table;
+
+  if (ConsumeExpectedToken(TokenKind::pu_l_paren))
+    return true;
+
+  // Empty case:
+  if (!ConsumeExpectedToken(TokenKind::pu_r_paren)) {
+    Elements.push_back(Table);
+    return false;
+  }
+
+  return true;
+
+}
+
+bool RootSignatureParser::ConsumeNextToken() {
+  CurTok++;
+  if (LastTok == CurTok) {
+    return true;
+  }
+  return false;
+}
+
+// Is given token one of the expected kinds
+static bool IsExpectedToken(TokenKind Kind, ArrayRef<TokenKind> AnyExpected) {
+  for (auto Expected : AnyExpected)
+    if (Kind == Expected)
+      return true;
+  return false;
+}
+
+bool RootSignatureParser::EnsureExpectedToken(TokenKind Expected) {
+  return EnsureExpectedToken(ArrayRef{Expected});
+}
+
+bool RootSignatureParser::EnsureExpectedToken(ArrayRef<TokenKind> AnyExpected) {
+  if (IsExpectedToken(CurTok->Kind, AnyExpected))
+    return false;
+
   return true;
 }
+
+bool RootSignatureParser::ConsumeExpectedToken(TokenKind Expected) {
+  return ConsumeExpectedToken(ArrayRef{Expected});
+}
+
+bool RootSignatureParser::ConsumeExpectedToken(
+    ArrayRef<TokenKind> AnyExpected) {
+  if (ConsumeNextToken())
+    return true;
+
+  return EnsureExpectedToken(AnyExpected);
+}
+
 } // namespace hlsl
 } // namespace clang
diff --git a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
index 87bceeeb3283ef..9e3b8d5e888490 100644
--- a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
+++ b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
@@ -305,4 +305,34 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseEmptyTest) {
   ASSERT_TRUE(Consumer->IsSatisfied());
 }
 
+TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
+  const llvm::StringLiteral Source = R"cc(
+    DescriptorTable()
+  )cc";
+
+  TrivialModuleLoader ModLoader;
+  auto PP = CreatePP(Source, ModLoader);
+  auto TokLoc = SourceLocation();
+
+  hlsl::RootSignatureLexer Lexer(Source, TokLoc, *PP);
+
+  SmallVector<hlsl::RootSignatureToken> Tokens;
+
+  // Test no diagnostics produced
+  Consumer->SetNoDiag();
+  ASSERT_FALSE(Lexer.Lex(Tokens));
+
+  SmallVector<RootElement> Elements;
+  hlsl::RootSignatureParser Parser(Elements, Tokens, Diags);
+
+  ASSERT_FALSE(Parser.Parse());
+  RootElement Elem = Elements[0];
+
+  // Test generated DescriptorTable start has correct default values
+  ASSERT_TRUE(std::holds_alternative<DescriptorTable>(Elem));
+  ASSERT_EQ(std::get<DescriptorTable>(Elem).NumClauses, (uint32_t)0);
+
+  ASSERT_TRUE(Consumer->IsSatisfied());
+}
+
 } // anonymous namespace
diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
index 4c196d29a01bbb..8f844d672fd554 100644
--- a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
+++ b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
@@ -20,8 +20,15 @@ namespace llvm {
 namespace hlsl {
 namespace root_signature {
 
-// Models RootElement
-using RootElement = std::variant<std::monostate>;
+// Definitions of the in-memory data layout structures
+
+// Models the end of a descriptor table and stores its visibility
+struct DescriptorTable {
+  uint32_t NumClauses = 0; // The number of clauses in the table
+};
+
+// Models RootElement : DescriptorTable
+using RootElement = std::variant<DescriptorTable>;
 
 } // namespace root_signature
 } // namespace hlsl

>From 123b4ac3404c4dff65ee4849550fbcb2a2b1e2fe Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Tue, 28 Jan 2025 18:16:43 +0000
Subject: [PATCH 08/18] add diagnostics to methods

---
 .../clang/Basic/DiagnosticParseKinds.td       |  3 +
 .../clang/Parse/ParseHLSLRootSignature.h      |  1 +
 clang/lib/Parse/ParseHLSLRootSignature.cpp    | 55 ++++++++++++++++++-
 .../Parse/ParseHLSLRootSignatureTest.cpp      | 50 +++++++++++++++++
 4 files changed, 108 insertions(+), 1 deletion(-)

diff --git a/clang/include/clang/Basic/DiagnosticParseKinds.td b/clang/include/clang/Basic/DiagnosticParseKinds.td
index 86fcae209c40db..52c3fcce14076a 100644
--- a/clang/include/clang/Basic/DiagnosticParseKinds.td
+++ b/clang/include/clang/Basic/DiagnosticParseKinds.td
@@ -1806,4 +1806,7 @@ def ext_hlsl_access_specifiers : ExtWarn<
 def err_hlsl_unsupported_component : Error<"invalid component '%0' used; expected 'x', 'y', 'z', or 'w'">;
 def err_hlsl_packoffset_invalid_reg : Error<"invalid resource class specifier '%0' for packoffset, expected 'c'">;
 
+// HLSL Root Signature Parser Diagnostics
+def err_hlsl_rootsig_unexpected_eos : Error<"unexpected end to token stream">;
+def err_hlsl_rootsig_unexpected_token_kind : Error<"expected the %select{following|one of the following}0 token kinds '%1'">;
 } // end of Parser diagnostics
diff --git a/clang/include/clang/Parse/ParseHLSLRootSignature.h b/clang/include/clang/Parse/ParseHLSLRootSignature.h
index 43b5535dbdbffd..88656ae6582ae0 100644
--- a/clang/include/clang/Parse/ParseHLSLRootSignature.h
+++ b/clang/include/clang/Parse/ParseHLSLRootSignature.h
@@ -15,6 +15,7 @@
 
 #include "clang/AST/APValue.h"
 #include "clang/Basic/DiagnosticLex.h"
+#include "clang/Basic/DiagnosticParse.h"
 #include "clang/Lex/LiteralSupport.h"
 #include "clang/Lex/Preprocessor.h"
 
diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp
index ff66763be9af3a..cca08361f5d27d 100644
--- a/clang/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -1,10 +1,59 @@
 #include "clang/Parse/ParseHLSLRootSignature.h"
 
+#include "llvm/Support/raw_ostream.h"
+
 using namespace llvm::hlsl::root_signature;
 
 namespace clang {
 namespace hlsl {
 
+// Helper definitions
+
+static std::string FormatTokenKinds(ArrayRef<TokenKind> Kinds) {
+  std::string TokenString;
+  llvm::raw_string_ostream Out(TokenString);
+  bool First = true;
+  for (auto Kind : Kinds) {
+    if (!First)
+      Out << ", ";
+    switch (Kind) {
+    case TokenKind::invalid:
+      break;
+    case TokenKind::int_literal:
+      Out << "integer literal";
+      break;
+    case TokenKind::bReg:
+      Out << "b register";
+      break;
+    case TokenKind::tReg:
+      Out << "t register";
+      break;
+    case TokenKind::uReg:
+      Out << "u register";
+      break;
+    case TokenKind::sReg:
+      Out << "s register";
+      break;
+#define PUNCTUATOR(X, Y)                                                       \
+  case TokenKind::pu_##X:                                                      \
+    Out << #Y;                                                                 \
+    break;
+#define KEYWORD(NAME)                                                          \
+  case TokenKind::kw_##NAME:                                                   \
+    Out << #NAME;                                                              \
+    break;
+#define ENUM(NAME, LIT)                                                        \
+  case TokenKind::en_##NAME:                                                   \
+    Out << LIT;                                                                \
+    break;
+#include "clang/Parse/HLSLRootSignatureTokenKinds.def"
+    }
+    First = false;
+  }
+
+  return TokenString;
+}
+
 // Lexer Definitions
 
 static bool IsNumberChar(char C) {
@@ -230,12 +279,13 @@ bool RootSignatureParser::ParseDescriptorTable() {
   }
 
   return true;
-
 }
 
 bool RootSignatureParser::ConsumeNextToken() {
   CurTok++;
   if (LastTok == CurTok) {
+    // CurTok is now the end iterator, so report at the last real token
+    Diags.Report((CurTok - 1)->TokLoc, diag::err_hlsl_rootsig_unexpected_eos);
     return true;
   }
   return false;
@@ -257,6 +307,9 @@ bool RootSignatureParser::EnsureExpectedToken(ArrayRef<TokenKind> AnyExpected) {
   if (IsExpectedToken(CurTok->Kind, AnyExpected))
     return false;
 
+  // Report unexpected token kind error
+  Diags.Report(CurTok->TokLoc, diag::err_hlsl_rootsig_unexpected_token_kind)
+      << (unsigned)(AnyExpected.size() != 1) << FormatTokenKinds(AnyExpected);
   return true;
 }
 
diff --git a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
index 9e3b8d5e888490..790ee33c22ab86 100644
--- a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
+++ b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
@@ -331,6 +331,56 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
   // Test generated DescriptorTable start has correct default values
   ASSERT_TRUE(std::holds_alternative<DescriptorTable>(Elem));
   ASSERT_EQ(std::get<DescriptorTable>(Elem).NumClauses, (uint32_t)0);
+  ASSERT_TRUE(Consumer->IsSatisfied());
+}
+
+// Invalid Parser Tests
+
+TEST_F(ParseHLSLRootSignatureTest, InvalidParseUnexpectedEOSTest) {
+  const llvm::StringLiteral Source = R"cc(
+    DescriptorTable
+  )cc";
+
+  TrivialModuleLoader ModLoader;
+  auto PP = CreatePP(Source, ModLoader);
+  auto TokLoc = SourceLocation();
+
+  // Test correct diagnostic produced
+  Consumer->SetExpected(diag::err_hlsl_rootsig_unexpected_eos);
+  hlsl::RootSignatureLexer Lexer(Source, TokLoc, *PP);
+
+  SmallVector<hlsl::RootSignatureToken> Tokens;
+  ASSERT_FALSE(Lexer.Lex(Tokens));
+
+  SmallVector<RootElement> Elements;
+  hlsl::RootSignatureParser Parser(Elements, Tokens, Diags);
+
+  ASSERT_TRUE(Parser.Parse());
+
+  ASSERT_TRUE(Consumer->IsSatisfied());
+}
+
+TEST_F(ParseHLSLRootSignatureTest, InvalidParseUnexpectedTokenTest) {
+  const llvm::StringLiteral Source = R"cc(
+    DescriptorTable()
+    space
+  )cc";
+
+  TrivialModuleLoader ModLoader;
+  auto PP = CreatePP(Source, ModLoader);
+  auto TokLoc = SourceLocation();
+
+  // Test correct diagnostic produced
+  Consumer->SetExpected(diag::err_hlsl_rootsig_unexpected_token_kind);
+  hlsl::RootSignatureLexer Lexer(Source, TokLoc, *PP);
+
+  SmallVector<hlsl::RootSignatureToken> Tokens;
+  ASSERT_FALSE(Lexer.Lex(Tokens));
+
+  SmallVector<RootElement> Elements;
+  hlsl::RootSignatureParser Parser(Elements, Tokens, Diags);
+
+  ASSERT_TRUE(Parser.Parse());
 
   ASSERT_TRUE(Consumer->IsSatisfied());
 }

>From 6f57e87caecf28498c671a1861e9f9ad2875ef38 Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Tue, 28 Jan 2025 18:37:42 +0000
Subject: [PATCH 09/18] add support for empty descriptor table clause

---
 .../clang/Parse/ParseHLSLRootSignature.h      | 22 +++++
 clang/lib/Parse/ParseHLSLRootSignature.cpp    | 80 ++++++++++++++++++-
 .../Parse/ParseHLSLRootSignatureTest.cpp      | 25 ++++++
 .../llvm/Frontend/HLSL/HLSLRootSignature.h    | 11 ++-
 4 files changed, 134 insertions(+), 4 deletions(-)

diff --git a/clang/include/clang/Parse/ParseHLSLRootSignature.h b/clang/include/clang/Parse/ParseHLSLRootSignature.h
index 88656ae6582ae0..d367b388582ea1 100644
--- a/clang/include/clang/Parse/ParseHLSLRootSignature.h
+++ b/clang/include/clang/Parse/ParseHLSLRootSignature.h
@@ -103,15 +103,28 @@ class RootSignatureParser {
   // Root Element helpers
   bool ParseRootElement(bool First);
   bool ParseDescriptorTable();
+  bool ParseDescriptorTableClause();
 
+  // Helper dispatch method
   // Increment the token iterator if we have not reached the end.
   // Return value denotes if we were already at the last token.
   bool ConsumeNextToken();
 
+  // Attempt to retrieve the next token, if TokenKind is invalid then there was
+  // no next token.
+  RootSignatureToken PeekNextToken();
+
   // Is the current token one of the expected kinds
   bool EnsureExpectedToken(TokenKind AnyExpected);
   bool EnsureExpectedToken(ArrayRef<TokenKind> AnyExpected);
 
+  // Peek if the next token is of the expected kind.
+  //
+  // Return value denotes if it failed to match the expected kind, either it is
+  // the end of the stream or it didn't match any of the expected kinds.
+  bool PeekExpectedToken(TokenKind Expected);
+  bool PeekExpectedToken(ArrayRef<TokenKind> AnyExpected);
+
   // Consume the next token and report an error if it is not of the expected
   // kind.
   //
@@ -120,6 +133,15 @@ class RootSignatureParser {
   bool ConsumeExpectedToken(TokenKind Expected);
   bool ConsumeExpectedToken(ArrayRef<TokenKind> AnyExpected);
 
+  // Peek if the next token is of the expected kind and if it is then consume
+  // it.
+  //
+  // Return value denotes if it failed to match the expected kind, either it is
+  // the end of the stream or it didn't match any of the expected kinds. It will
+  // not report an error if there isn't a match.
+  bool TryConsumeExpectedToken(TokenKind Expected);
+  bool TryConsumeExpectedToken(ArrayRef<TokenKind> Expected);
+
 private:
   SmallVector<rs::RootElement> &Elements;
   SmallVector<RootSignatureToken>::const_iterator CurTok;
diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp
index cca08361f5d27d..41282580024c52 100644
--- a/clang/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -273,12 +273,64 @@ bool RootSignatureParser::ParseDescriptorTable() {
     return true;
 
   // Empty case:
-  if (!ConsumeExpectedToken(TokenKind::pu_r_paren)) {
+  if (!TryConsumeExpectedToken(TokenKind::pu_r_paren)) {
     Elements.push_back(Table);
     return false;
   }
 
-  return true;
+  // Iterate through all the defined clauses
+  do {
+    if (ParseDescriptorTableClause())
+      return true;
+    Table.NumClauses++;
+  } while (!TryConsumeExpectedToken(TokenKind::pu_comma));
+
+  if (ConsumeExpectedToken(TokenKind::pu_r_paren))
+    return true;
+
+  Elements.push_back(Table);
+  return false;
+}
+
+bool RootSignatureParser::ParseDescriptorTableClause() {
+  if (ConsumeExpectedToken({TokenKind::kw_CBV, TokenKind::kw_SRV,
+                            TokenKind::kw_UAV, TokenKind::kw_Sampler}))
+    return true;
+
+  DescriptorTableClause Clause;
+  switch (CurTok->Kind) {
+  case TokenKind::kw_CBV:
+    Clause.Type = ClauseType::CBuffer;
+    break;
+  case TokenKind::kw_SRV:
+    Clause.Type = ClauseType::SRV;
+    break;
+  case TokenKind::kw_UAV:
+    Clause.Type = ClauseType::UAV;
+    break;
+  case TokenKind::kw_Sampler:
+    Clause.Type = ClauseType::Sampler;
+    break;
+  default:
+    llvm_unreachable("Switch for an expected token was not provided");
+    return true;
+  }
+  if (ConsumeExpectedToken(TokenKind::pu_l_paren))
+    return true;
+
+  if (ConsumeExpectedToken(TokenKind::pu_r_paren))
+    return true;
+
+  Elements.push_back(Clause);
+  return false;
+}
+
+RootSignatureToken RootSignatureParser::PeekNextToken() {
+  // Default to an invalid token, which signals that there is no next token
+  RootSignatureToken Token = RootSignatureToken(SourceLocation());
+  if (CurTok != LastTok && CurTok + 1 != LastTok)
+    Token = *(CurTok + 1);
+  return Token;
 }
 
 bool RootSignatureParser::ConsumeNextToken() {
@@ -313,6 +365,19 @@ bool RootSignatureParser::EnsureExpectedToken(ArrayRef<TokenKind> AnyExpected) {
   return true;
 }
 
+bool RootSignatureParser::PeekExpectedToken(TokenKind Expected) {
+  return PeekExpectedToken(ArrayRef{Expected});
+}
+
+bool RootSignatureParser::PeekExpectedToken(ArrayRef<TokenKind> AnyExpected) {
+  RootSignatureToken Token = PeekNextToken();
+  if (Token.Kind == TokenKind::invalid)
+    return true;
+  if (IsExpectedToken(Token.Kind, AnyExpected))
+    return false;
+  return true;
+}
+
 bool RootSignatureParser::ConsumeExpectedToken(TokenKind Expected) {
   return ConsumeExpectedToken(ArrayRef{Expected});
 }
@@ -325,5 +390,16 @@ bool RootSignatureParser::ConsumeExpectedToken(
   return EnsureExpectedToken(AnyExpected);
 }
 
+bool RootSignatureParser::TryConsumeExpectedToken(TokenKind Expected) {
+  return TryConsumeExpectedToken(ArrayRef{Expected});
+}
+
+bool RootSignatureParser::TryConsumeExpectedToken(
+    ArrayRef<TokenKind> AnyExpected) {
+  if (PeekExpectedToken(AnyExpected))
+    return true;
+  return ConsumeNextToken();
+}
+
 } // namespace hlsl
 } // namespace clang
diff --git a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
index 790ee33c22ab86..ff905813fcaac8 100644
--- a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
+++ b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
@@ -307,6 +307,12 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseEmptyTest) {
 
 TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
   const llvm::StringLiteral Source = R"cc(
+    DescriptorTable(
+      CBV(),
+      SRV(),
+      Sampler(),
+      UAV()
+    ),
     DescriptorTable()
   )cc";
 
@@ -327,7 +333,26 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
 
   ASSERT_FALSE(Parser.Parse());
   RootElement Elem = Elements[0];
+  ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::CBuffer);
 
+  Elem = Elements[1];
+  ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::SRV);
+
+  Elem = Elements[2];
+  ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::Sampler);
+
+  Elem = Elements[3];
+  ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::UAV);
+
+  Elem = Elements[4];
+  ASSERT_TRUE(std::holds_alternative<DescriptorTable>(Elem));
+  ASSERT_EQ(std::get<DescriptorTable>(Elem).NumClauses, (uint32_t)4);
+
+  Elem = Elements[5];
   // Test generated DescriptorTable start has correct default values
   ASSERT_TRUE(std::holds_alternative<DescriptorTable>(Elem));
   ASSERT_EQ(std::get<DescriptorTable>(Elem).NumClauses, (uint32_t)0);
diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
index 8f844d672fd554..9b7916107c1f17 100644
--- a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
+++ b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
@@ -14,6 +14,7 @@
 #ifndef LLVM_FRONTEND_HLSL_HLSLROOTSIGNATURE_H
 #define LLVM_FRONTEND_HLSL_HLSLROOTSIGNATURE_H
 
+#include "llvm/Support/DXILABI.h"
 #include <variant>
 
 namespace llvm {
@@ -27,8 +28,14 @@ struct DescriptorTable {
   uint32_t NumClauses = 0; // The number of clauses in the table
 };
 
-// Models RootElement : DescriptorTable
-using RootElement = std::variant<DescriptorTable>;
+// Models DTClause : CBV | SRV | UAV | Sampler, by collecting like parameters
+using ClauseType = llvm::dxil::ResourceClass;
+struct DescriptorTableClause {
+  ClauseType Type;
+};
+
+// Models RootElement : DescriptorTable | DescriptorTableClause
+using RootElement = std::variant<DescriptorTable, DescriptorTableClause>;
 
 } // namespace root_signature
 } // namespace hlsl

>From 56664e29fbbfdda6fb666ea7103cb7f422da3016 Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Tue, 28 Jan 2025 18:52:04 +0000
Subject: [PATCH 10/18] add support for parsing registers

---
 .../clang/Parse/ParseHLSLRootSignature.h      |  2 ++
 clang/lib/Parse/ParseHLSLRootSignature.cpp    | 31 +++++++++++++++++++
 .../Parse/ParseHLSLRootSignatureTest.cpp      | 23 +++++++++++---
 .../llvm/Frontend/HLSL/HLSLRootSignature.h    |  8 +++++
 4 files changed, 60 insertions(+), 4 deletions(-)

diff --git a/clang/include/clang/Parse/ParseHLSLRootSignature.h b/clang/include/clang/Parse/ParseHLSLRootSignature.h
index d367b388582ea1..e66fad7f18f18c 100644
--- a/clang/include/clang/Parse/ParseHLSLRootSignature.h
+++ b/clang/include/clang/Parse/ParseHLSLRootSignature.h
@@ -106,6 +106,8 @@ class RootSignatureParser {
   bool ParseDescriptorTableClause();
 
   // Helper dispatch method
+  bool ParseRegister(rs::Register *Reg);
+
   // Increment the token iterator if we have not reached the end.
   // Return value denotes if we were already at the last token.
   bool ConsumeNextToken();
diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp
index 41282580024c52..203f0bdddb04b0 100644
--- a/clang/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -318,6 +318,13 @@ bool RootSignatureParser::ParseDescriptorTableClause() {
   if (ConsumeExpectedToken(TokenKind::pu_l_paren))
     return true;
 
+  // Consume mandatory Register paramater
+  if (ConsumeExpectedToken(
+          {TokenKind::bReg, TokenKind::tReg, TokenKind::uReg, TokenKind::sReg}))
+    return true;
+  if (ParseRegister(&Clause.Register))
+    return true;
+
   if (ConsumeExpectedToken(TokenKind::pu_r_paren))
     return true;
 
@@ -325,6 +332,30 @@ bool RootSignatureParser::ParseDescriptorTableClause() {
   return false;
 }
 
+bool RootSignatureParser::ParseRegister(Register *Register) {
+  switch (CurTok->Kind) {
+  case TokenKind::bReg:
+    Register->ViewType = RegisterType::BReg;
+    break;
+  case TokenKind::tReg:
+    Register->ViewType = RegisterType::TReg;
+    break;
+  case TokenKind::uReg:
+    Register->ViewType = RegisterType::UReg;
+    break;
+  case TokenKind::sReg:
+    Register->ViewType = RegisterType::SReg;
+    break;
+  default:
+    // Callers invoke this only after consuming one of the four register
+    // tokens, so any other kind is a programming error. llvm_unreachable does
+    // not return, so no fallback return value is needed here.
+    llvm_unreachable("Switch for an expected token was not provided");
+  }
+
+  // The numeric part of the register token (e.g. 42 in `t42`) becomes the
+  // register number.
+  // NOTE(review): getExtValue() yields a 64-bit value that is implicitly
+  // narrowed into the uint32_t Number, so indices above UINT32_MAX are
+  // silently truncated -- confirm whether a range diagnostic is wanted.
+  Register->Number = CurTok->NumLiteral.getInt().getExtValue();
+
+  return false;
+}
+
 RootSignatureToken RootSignatureParser::PeekNextToken() {
   // Create an invalid token
   RootSignatureToken Token = RootSignatureToken(SourceLocation());
diff --git a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
index ff905813fcaac8..4788589eea2d24 100644
--- a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
+++ b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
@@ -308,10 +308,10 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseEmptyTest) {
 TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
   const llvm::StringLiteral Source = R"cc(
     DescriptorTable(
-      CBV(),
-      SRV(),
-      Sampler(),
-      UAV()
+      CBV(b0),
+      SRV(t42),
+      Sampler(s987),
+      UAV(u987234)
     ),
     DescriptorTable()
   )cc";
@@ -335,18 +335,33 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
   RootElement Elem = Elements[0];
   ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::CBuffer);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.ViewType,
+            RegisterType::BReg);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.Number, (uint32_t)0);
 
   Elem = Elements[1];
   ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::SRV);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.ViewType,
+            RegisterType::TReg);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.Number,
+            (uint32_t)42);
 
   Elem = Elements[2];
   ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::Sampler);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.ViewType,
+            RegisterType::SReg);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.Number,
+            (uint32_t)987);
 
   Elem = Elements[3];
   ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::UAV);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.ViewType,
+            RegisterType::UReg);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.Number,
+            (uint32_t)987234);
 
   Elem = Elements[4];
   ASSERT_TRUE(std::holds_alternative<DescriptorTable>(Elem));
diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
index 9b7916107c1f17..7779fd2e803ac6 100644
--- a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
+++ b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
@@ -23,6 +23,13 @@ namespace root_signature {
 
 // Definitions of the in-memory data layout structures
 
+// Models the different registers: bReg | tReg | uReg | sReg
+// RegisterType records which register prefix (b, t, u or s) was written; the
+// root signature parser maps the bReg/tReg/uReg/sReg tokens onto
+// BReg/TReg/UReg/SReg respectively.
+enum class RegisterType { BReg, TReg, UReg, SReg };
+// A parsed register reference such as `t42`: ViewType holds the prefix kind
+// and Number the index that follows it.
+// NOTE(review): neither member has a default initializer, so a
+// default-initialized Register holds indeterminate values until it is parsed.
+struct Register {
+  RegisterType ViewType;
+  uint32_t Number;
+};
+
 // Models the end of a descriptor table and stores its visibility
 struct DescriptorTable {
   uint32_t NumClauses = 0; // The number of clauses in the table
@@ -32,6 +39,7 @@ struct DescriptorTable {
 using ClauseType = llvm::dxil::ResourceClass;
 struct DescriptorTableClause {
   ClauseType Type;
+  Register Register;
 };
 
 // Models RootElement : DescriptorTable | DescriptorTableClause

>From bcc4357b4cc7f774864794637fc6818a17a7f60d Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Tue, 28 Jan 2025 19:00:29 +0000
Subject: [PATCH 11/18] add support for optional parameters

- use numDescriptors as an example
---
 .../clang/Parse/ParseHLSLRootSignature.h      | 15 +++++
 clang/lib/Parse/ParseHLSLRootSignature.cpp    | 58 +++++++++++++++++++
 .../Parse/ParseHLSLRootSignatureTest.cpp      |  6 +-
 .../llvm/Frontend/HLSL/HLSLRootSignature.h    |  4 ++
 4 files changed, 82 insertions(+), 1 deletion(-)

diff --git a/clang/include/clang/Parse/ParseHLSLRootSignature.h b/clang/include/clang/Parse/ParseHLSLRootSignature.h
index e66fad7f18f18c..83addc0d5b092e 100644
--- a/clang/include/clang/Parse/ParseHLSLRootSignature.h
+++ b/clang/include/clang/Parse/ParseHLSLRootSignature.h
@@ -106,7 +106,22 @@ class RootSignatureParser {
   bool ParseDescriptorTableClause();
 
   // Helper dispatch method
+  //
+  // These will switch on the Variant kind to dispatch to the respective Parse
+  // method and store the parsed value back into Ref.
+  //
+  // It is helpful to have a generalized dispatch method so that when we need
+  // to parse multiple optional parameters in any order, we can invoke this
+  // method
+  bool ParseParam(rs::ParamType Ref);
+
+  // Parse as many optional parameters as possible in any order
+  bool
+  ParseOptionalParams(llvm::SmallDenseMap<TokenKind, rs::ParamType> RefMap);
+
+  // Common parsing helpers
   bool ParseRegister(rs::Register *Reg);
+  bool ParseUInt(uint32_t *X);
 
   // Increment the token iterator if we have not reached the end.
   // Return value denotes if we were already at the last token.
diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp
index 203f0bdddb04b0..6deca82bc32430 100644
--- a/clang/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -325,6 +325,13 @@ bool RootSignatureParser::ParseDescriptorTableClause() {
   if (ParseRegister(&Clause.Register))
     return true;
 
+  // Parse optional paramaters
+  llvm::SmallDenseMap<TokenKind, rs::ParamType> RefMap = {
+      {TokenKind::kw_numDescriptors, &Clause.NumDescriptors},
+  };
+  if (ParseOptionalParams({RefMap}))
+    return true;
+
   if (ConsumeExpectedToken(TokenKind::pu_r_paren))
     return true;
 
@@ -332,6 +339,57 @@ bool RootSignatureParser::ParseDescriptorTableClause() {
   return false;
 }
 
+// Aggregate that inherits from every callable it is built from and re-exports
+// all of their call operators. This lets std::visit dispatch a variant to one
+// lambda per alternative via ordinary overload resolution.
+template <typename... Callables> struct OverloadedMethods : Callables... {
+  using Callables::operator()...;
+};
+// C++17 deduction guide so `OverloadedMethods{lambda1, lambda2}` deduces the
+// lambda types without spelling them out.
+template <typename... Callables>
+OverloadedMethods(Callables...) -> OverloadedMethods<Callables...>;
+
+// Parse the `= Value` tail of a `Keyword = Value` parameter whose keyword has
+// already been consumed. The value is written through whichever pointer
+// alternative Ref currently holds. Returns true on error.
+// std::visit requires a handler for every alternative of ParamType, so Error
+// is always assigned by exactly one of the lambdas before it is returned.
+bool RootSignatureParser::ParseParam(ParamType Ref) {
+  if (ConsumeExpectedToken(TokenKind::pu_equal))
+    return true;
+
+  bool Error;
+  std::visit(OverloadedMethods{[&](uint32_t *X) { Error = ParseUInt(X); },
+  }, Ref);
+
+  return Error;
+}
+
+// Parse a run of `, Keyword = Value` optional parameters, accepted in any
+// order, until no further comma follows. RefMap maps each accepted keyword
+// token to the location its parsed value is written to. Repeating a keyword
+// is an error. Returns true on error.
+bool RootSignatureParser::ParseOptionalParams(
+    llvm::SmallDenseMap<TokenKind, rs::ParamType> RefMap) {
+  SmallVector<TokenKind> ParamKeywords;
+  ParamKeywords.reserve(RefMap.size());
+  for (const auto &RefPair : RefMap)
+    ParamKeywords.push_back(RefPair.first);
+
+  // Keep track of which keywords have been seen to report duplicates
+  llvm::SmallDenseSet<TokenKind> Seen;
+
+  // TryConsumeExpectedToken signals a successful consume by returning false,
+  // so loop for as long as another comma-separated parameter follows.
+  while (!TryConsumeExpectedToken(TokenKind::pu_comma)) {
+    if (ConsumeExpectedToken(ParamKeywords))
+      return true;
+
+    TokenKind ParamKind = CurTok->Kind;
+    if (Seen.contains(ParamKind)) {
+      return true;
+    }
+    Seen.insert(ParamKind);
+
+    // ConsumeExpectedToken only accepts keys of RefMap, so this lookup cannot
+    // miss; unlike operator[], lookup never inserts into the map.
+    if (ParseParam(RefMap.lookup(ParamKind)))
+      return true;
+  }
+
+  return false;
+}
+
+bool RootSignatureParser::ParseUInt(uint32_t *X) {
+  if (ConsumeExpectedToken(TokenKind::int_literal))
+    return true;
+
+  // Store the consumed integer literal through X; returns true on error.
+  // NOTE(review): getExtValue() yields a 64-bit value that is implicitly
+  // narrowed to uint32_t here, so literals above UINT32_MAX are silently
+  // truncated -- confirm whether a range diagnostic is wanted.
+  *X = CurTok->NumLiteral.getInt().getExtValue();
+  return false;
+}
+
 bool RootSignatureParser::ParseRegister(Register *Register) {
   switch (CurTok->Kind) {
   case TokenKind::bReg:
diff --git a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
index 4788589eea2d24..e21cf4e24f591b 100644
--- a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
+++ b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
@@ -309,7 +309,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
   const llvm::StringLiteral Source = R"cc(
     DescriptorTable(
       CBV(b0),
-      SRV(t42),
+      SRV(t42, numDescriptors = 4),
       Sampler(s987),
       UAV(u987234)
     ),
@@ -338,6 +338,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.ViewType,
             RegisterType::BReg);
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.Number, (uint32_t)0);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).NumDescriptors, (uint32_t)1);
 
   Elem = Elements[1];
   ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
@@ -346,6 +347,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
             RegisterType::TReg);
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.Number,
             (uint32_t)42);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).NumDescriptors, (uint32_t)4);
 
   Elem = Elements[2];
   ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
@@ -354,6 +356,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
             RegisterType::SReg);
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.Number,
             (uint32_t)987);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).NumDescriptors, (uint32_t)1);
 
   Elem = Elements[3];
   ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
@@ -362,6 +365,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
             RegisterType::UReg);
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.Number,
             (uint32_t)987234);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).NumDescriptors, (uint32_t)1);
 
   Elem = Elements[4];
   ASSERT_TRUE(std::holds_alternative<DescriptorTable>(Elem));
diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
index 7779fd2e803ac6..8f5655681c7b7c 100644
--- a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
+++ b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
@@ -40,11 +40,15 @@ using ClauseType = llvm::dxil::ResourceClass;
 struct DescriptorTableClause {
   ClauseType Type;
   Register Register;
+  uint32_t NumDescriptors = 1;
 };
 
 // Models RootElement : DescriptorTable | DescriptorTableClause
 using RootElement = std::variant<DescriptorTable, DescriptorTableClause>;
 
+// Models a reference to all assignment parameter types that any RootElement
+// may have. Things of the form: Keyword = Param
+using ParamType = std::variant<uint32_t *>;
 } // namespace root_signature
 } // namespace hlsl
 } // namespace llvm

>From bf655d414ce3d309f730e72206d6a2cb4ee213f8 Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Tue, 28 Jan 2025 19:19:33 +0000
Subject: [PATCH 12/18] add diagnostic for repeated parameter

---
 .../clang/Basic/DiagnosticParseKinds.td       |  1 +
 clang/lib/Parse/ParseHLSLRootSignature.cpp    |  2 ++
 .../Parse/ParseHLSLRootSignatureTest.cpp      | 26 +++++++++++++++++++
 3 files changed, 29 insertions(+)

diff --git a/clang/include/clang/Basic/DiagnosticParseKinds.td b/clang/include/clang/Basic/DiagnosticParseKinds.td
index 52c3fcce14076a..4994d2339ce578 100644
--- a/clang/include/clang/Basic/DiagnosticParseKinds.td
+++ b/clang/include/clang/Basic/DiagnosticParseKinds.td
@@ -1809,4 +1809,5 @@ def err_hlsl_packoffset_invalid_reg : Error<"invalid resource class specifier '%
 // HLSL Root Signature Parser Diagnostics
 def err_hlsl_rootsig_unexpected_eos : Error<"unexpected end to token stream">;
 def err_hlsl_rootsig_unexpected_token_kind : Error<"expected the %select{following|one of the following}0 token kinds '%1'">;
+def err_hlsl_rootsig_repeat_param : Error<"specified the same parameter '%0' multiple times">;
 } // end of Parser diagnostics
diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp
index 6deca82bc32430..e4322cc973815c 100644
--- a/clang/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -371,6 +371,8 @@ bool RootSignatureParser::ParseOptionalParams(
 
     TokenKind ParamKind = CurTok->Kind;
     if (Seen.contains(ParamKind)) {
+      Diags.Report(CurTok->TokLoc, diag::err_hlsl_rootsig_repeat_param)
+          << FormatTokenKinds(ParamKind);
       return true;
     }
     Seen.insert(ParamKind);
diff --git a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
index e21cf4e24f591b..27b663b26f4bf1 100644
--- a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
+++ b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
@@ -429,4 +429,30 @@ TEST_F(ParseHLSLRootSignatureTest, InvalidParseUnexpectedTokenTest) {
   ASSERT_TRUE(Consumer->IsSatisfied());
 }
 
+// A descriptor-table clause that specifies the same optional parameter
+// (numDescriptors) twice must make Parse() fail and emit
+// err_hlsl_rootsig_repeat_param.
+TEST_F(ParseHLSLRootSignatureTest, InvalidParseRepeatedParamTest) {
+  const llvm::StringLiteral Source = R"cc(
+    DescriptorTable(
+      CBV(b0, numDescriptors = 3, numDescriptors = 1)
+    )
+  )cc";
+
+  TrivialModuleLoader ModLoader;
+  auto PP = CreatePP(Source, ModLoader);
+  auto TokLoc = SourceLocation();
+
+  // Test correct diagnostic produced
+  Consumer->SetExpected(diag::err_hlsl_rootsig_repeat_param);
+  hlsl::RootSignatureLexer Lexer(Source, TokLoc, *PP);
+
+  // Lexing itself must succeed; the error is only detected by the parser.
+  SmallVector<hlsl::RootSignatureToken> Tokens;
+  ASSERT_FALSE(Lexer.Lex(Tokens));
+
+  SmallVector<RootElement> Elements;
+  hlsl::RootSignatureParser Parser(Elements, Tokens, Diags);
+
+  ASSERT_TRUE(Parser.Parse());
+
+  ASSERT_TRUE(Consumer->IsSatisfied());
+}
+
 } // anonymous namespace

>From dc050e5e663c6da521fef4f4e5a99456e752396a Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Tue, 28 Jan 2025 19:23:50 +0000
Subject: [PATCH 13/18] add space optional parameter

- demonstrate that optional parameters can be specified in any order
---
 clang/lib/Parse/ParseHLSLRootSignature.cpp           | 1 +
 clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp | 8 ++++++--
 llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h  | 1 +
 3 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp
index e4322cc973815c..1af9b519bc5e41 100644
--- a/clang/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -328,6 +328,7 @@ bool RootSignatureParser::ParseDescriptorTableClause() {
   // Parse optional paramaters
   llvm::SmallDenseMap<TokenKind, rs::ParamType> RefMap = {
       {TokenKind::kw_numDescriptors, &Clause.NumDescriptors},
+      {TokenKind::kw_space, &Clause.Space},
   };
   if (ParseOptionalParams({RefMap}))
     return true;
diff --git a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
index 27b663b26f4bf1..49b4777800824f 100644
--- a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
+++ b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
@@ -309,8 +309,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
   const llvm::StringLiteral Source = R"cc(
     DescriptorTable(
       CBV(b0),
-      SRV(t42, numDescriptors = 4),
-      Sampler(s987),
+      SRV(t42, space = 3, numDescriptors = 4),
+      Sampler(s987, space = 2),
       UAV(u987234)
     ),
     DescriptorTable()
@@ -339,6 +339,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
             RegisterType::BReg);
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.Number, (uint32_t)0);
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).NumDescriptors, (uint32_t)1);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, (uint32_t)0);
 
   Elem = Elements[1];
   ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
@@ -348,6 +349,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.Number,
             (uint32_t)42);
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).NumDescriptors, (uint32_t)4);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, (uint32_t)3);
 
   Elem = Elements[2];
   ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
@@ -357,6 +359,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.Number,
             (uint32_t)987);
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).NumDescriptors, (uint32_t)1);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, (uint32_t)2);
 
   Elem = Elements[3];
   ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
@@ -366,6 +369,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.Number,
             (uint32_t)987234);
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).NumDescriptors, (uint32_t)1);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, (uint32_t)0);
 
   Elem = Elements[4];
   ASSERT_TRUE(std::holds_alternative<DescriptorTable>(Elem));
diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
index 8f5655681c7b7c..608e79494fa649 100644
--- a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
+++ b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
@@ -41,6 +41,7 @@ struct DescriptorTableClause {
   ClauseType Type;
   Register Register;
   uint32_t NumDescriptors = 1;
+  uint32_t Space = 0;
 };
 
 // Models RootElement : DescriptorTable | DescriptorTableClause

>From b8d8e43517b8b72d78eb83776457893d5ebf258b Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Tue, 28 Jan 2025 19:29:21 +0000
Subject: [PATCH 14/18] add support for custom parameter parsing -
 DescriptorRangeOffset

---
 .../clang/Parse/ParseHLSLRootSignature.h      |  1 +
 clang/lib/Parse/ParseHLSLRootSignature.cpp    | 19 +++++++++++++++++++
 .../Parse/ParseHLSLRootSignatureTest.cpp      | 12 ++++++++++--
 .../llvm/Frontend/HLSL/HLSLRootSignature.h    |  9 ++++++++-
 4 files changed, 38 insertions(+), 3 deletions(-)

diff --git a/clang/include/clang/Parse/ParseHLSLRootSignature.h b/clang/include/clang/Parse/ParseHLSLRootSignature.h
index 83addc0d5b092e..2e05e71445eb39 100644
--- a/clang/include/clang/Parse/ParseHLSLRootSignature.h
+++ b/clang/include/clang/Parse/ParseHLSLRootSignature.h
@@ -122,6 +122,7 @@ class RootSignatureParser {
   // Common parsing helpers
   bool ParseRegister(rs::Register *Reg);
   bool ParseUInt(uint32_t *X);
+  bool ParseDescriptorRangeOffset(rs::DescriptorRangeOffset *X);
 
   // Increment the token iterator if we have not reached the end.
   // Return value denotes if we were already at the last token.
diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp
index 1af9b519bc5e41..095c9b7a3e69d3 100644
--- a/clang/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -329,6 +329,7 @@ bool RootSignatureParser::ParseDescriptorTableClause() {
   llvm::SmallDenseMap<TokenKind, rs::ParamType> RefMap = {
       {TokenKind::kw_numDescriptors, &Clause.NumDescriptors},
       {TokenKind::kw_space, &Clause.Space},
+      {TokenKind::kw_offset, &Clause.Offset},
   };
   if (ParseOptionalParams({RefMap}))
     return true;
@@ -352,6 +353,9 @@ bool RootSignatureParser::ParseParam(ParamType Ref) {
 
   bool Error;
   std::visit(OverloadedMethods{[&](uint32_t *X) { Error = ParseUInt(X); },
+                               [&](DescriptorRangeOffset *X) {
+                                 Error = ParseDescriptorRangeOffset(X);
+                               },
   }, Ref);
 
   return Error;
@@ -385,6 +389,21 @@ bool RootSignatureParser::ParseOptionalParams(
   return false;
 }
 
+// Parse the value of an `offset` parameter: either an integer literal, or
+// DESCRIPTOR_RANGE_OFFSET_APPEND, which maps to the DescriptorTableOffsetAppend
+// sentinel. Returns true on error.
+bool RootSignatureParser::ParseDescriptorRangeOffset(DescriptorRangeOffset *X) {
+  if (ConsumeExpectedToken(
+          {TokenKind::int_literal, TokenKind::en_DescriptorRangeOffsetAppend}))
+    return true;
+
+  const bool IsAppend =
+      CurTok->Kind == TokenKind::en_DescriptorRangeOffsetAppend;
+  // Only read the numeric literal when the token actually is one.
+  *X = IsAppend
+           ? DescriptorTableOffsetAppend
+           : DescriptorRangeOffset(CurTok->NumLiteral.getInt().getExtValue());
+  return false;
+}
+
 bool RootSignatureParser::ParseUInt(uint32_t *X) {
   if (ConsumeExpectedToken(TokenKind::int_literal))
     return true;
diff --git a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
index 49b4777800824f..003f09f353338e 100644
--- a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
+++ b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
@@ -309,8 +309,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
   const llvm::StringLiteral Source = R"cc(
     DescriptorTable(
       CBV(b0),
-      SRV(t42, space = 3, numDescriptors = 4),
-      Sampler(s987, space = 2),
+      SRV(t42, space = 3, offset = 32, numDescriptors = 4),
+      Sampler(s987, space = 2, offset = DESCRIPTOR_RANGE_OFFSET_APPEND),
       UAV(u987234)
     ),
     DescriptorTable()
@@ -340,6 +340,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.Number, (uint32_t)0);
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).NumDescriptors, (uint32_t)1);
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, (uint32_t)0);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Offset,
+            DescriptorRangeOffset(DescriptorTableOffsetAppend));
 
   Elem = Elements[1];
   ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
@@ -350,6 +352,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
             (uint32_t)42);
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).NumDescriptors, (uint32_t)4);
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, (uint32_t)3);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Offset,
+            DescriptorRangeOffset(32));
 
   Elem = Elements[2];
   ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
@@ -360,6 +364,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
             (uint32_t)987);
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).NumDescriptors, (uint32_t)1);
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, (uint32_t)2);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Offset,
+            DescriptorRangeOffset(DescriptorTableOffsetAppend));
 
   Elem = Elements[3];
   ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
@@ -370,6 +376,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
             (uint32_t)987234);
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).NumDescriptors, (uint32_t)1);
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, (uint32_t)0);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Offset,
+            DescriptorRangeOffset(DescriptorTableOffsetAppend));
 
   Elem = Elements[4];
   ASSERT_TRUE(std::holds_alternative<DescriptorTable>(Elem));
diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
index 608e79494fa649..fc0fa160eebd94 100644
--- a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
+++ b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
@@ -21,6 +21,10 @@ namespace llvm {
 namespace hlsl {
 namespace root_signature {
 
+// Definition of the various enumerations and flags
+
+enum class DescriptorRangeOffset : uint32_t;
+
 // Definitions of the in-memory data layout structures
 
 // Models the different registers: bReg | tReg | uReg | sReg
@@ -35,6 +39,8 @@ struct DescriptorTable {
   uint32_t NumClauses = 0; // The number of clauses in the table
 };
 
+static const DescriptorRangeOffset DescriptorTableOffsetAppend =
+    DescriptorRangeOffset(0xffffffff);
 // Models DTClause : CBV | SRV | UAV | Sampler, by collecting like parameters
 using ClauseType = llvm::dxil::ResourceClass;
 struct DescriptorTableClause {
@@ -42,6 +48,7 @@ struct DescriptorTableClause {
   Register Register;
   uint32_t NumDescriptors = 1;
   uint32_t Space = 0;
+  DescriptorRangeOffset Offset = DescriptorTableOffsetAppend;
 };
 
 // Models RootElement : DescriptorTable | DescriptorTableClause
@@ -49,7 +56,7 @@ using RootElement = std::variant<DescriptorTable, DescriptorTableClause>;
 
 // Models a reference to all assignment parameter types that any RootElement
 // may have. Things of the form: Keyword = Param
-using ParamType = std::variant<uint32_t *>;
+using ParamType = std::variant<uint32_t *, DescriptorRangeOffset *>;
 } // namespace root_signature
 } // namespace hlsl
 } // namespace llvm

>From 68dda01346650effac035b57187c74a23eb8cf9a Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Tue, 28 Jan 2025 19:38:57 +0000
Subject: [PATCH 15/18] add support for shader visibility

- introduces the ParseEnum function that will parse any of the ENUM
token definitions
---
 .../clang/Parse/ParseHLSLRootSignature.h      |  6 +++
 clang/lib/Parse/ParseHLSLRootSignature.cpp    | 51 +++++++++++++++++++
 .../Parse/ParseHLSLRootSignatureTest.cpp      | 32 ++++++++++++
 .../llvm/Frontend/HLSL/HLSLRootSignature.h    | 15 +++++-
 4 files changed, 103 insertions(+), 1 deletion(-)

diff --git a/clang/include/clang/Parse/ParseHLSLRootSignature.h b/clang/include/clang/Parse/ParseHLSLRootSignature.h
index 2e05e71445eb39..4fc04ca72bf08c 100644
--- a/clang/include/clang/Parse/ParseHLSLRootSignature.h
+++ b/clang/include/clang/Parse/ParseHLSLRootSignature.h
@@ -124,6 +124,12 @@ class RootSignatureParser {
   bool ParseUInt(uint32_t *X);
   bool ParseDescriptorRangeOffset(rs::DescriptorRangeOffset *X);
 
+  // Various flags/enum parsing helpers
+  template <typename EnumType>
+  bool ParseEnum(llvm::SmallDenseMap<TokenKind, EnumType> EnumMap,
+                 EnumType *Enum);
+  bool ParseShaderVisibility(rs::ShaderVisibility *Enum);
+
   // Increment the token iterator if we have not reached the end.
   // Return value denotes if we were already at the last token.
   bool ConsumeNextToken();
diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp
index 095c9b7a3e69d3..e758c4b95abbb4 100644
--- a/clang/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -278,8 +278,23 @@ bool RootSignatureParser::ParseDescriptorTable() {
     return false;
   }
 
+  bool SeenVisibility = false;
   // Iterate through all the defined clauses
   do {
+    // Handle the visibility parameter
+    if (!TryConsumeExpectedToken(TokenKind::kw_visibility)) {
+      if (SeenVisibility) {
+        Diags.Report(CurTok->TokLoc, diag::err_hlsl_rootsig_repeat_param)
+            << FormatTokenKinds(CurTok->Kind);
+        return true;
+      }
+      SeenVisibility = true;
+      if (ParseParam(&Table.Visibility))
+        return true;
+      continue;
+    }
+
+    // Otherwise, we expect a clause
     if (ParseDescriptorTableClause())
       return true;
     Table.NumClauses++;
@@ -356,6 +371,9 @@ bool RootSignatureParser::ParseParam(ParamType Ref) {
                                [&](DescriptorRangeOffset *X) {
                                  Error = ParseDescriptorRangeOffset(X);
                                },
+                               [&](ShaderVisibility *Enum) {
+                                 Error = ParseShaderVisibility(Enum);
+                               },
   }, Ref);
 
   return Error;
@@ -436,6 +454,39 @@ bool RootSignatureParser::ParseRegister(Register *Register) {
   return false;
 }
 
+// Consume the next token, which must be one of EnumMap's keys, and store the
+// enumerator it maps to through Enum. Returns true on error.
+template <typename EnumType>
+bool RootSignatureParser::ParseEnum(
+    llvm::SmallDenseMap<TokenKind, EnumType> EnumMap, EnumType *Enum) {
+  SmallVector<TokenKind> EnumToks;
+  EnumToks.reserve(EnumMap.size());
+  for (const auto &EnumPair : EnumMap)
+    EnumToks.push_back(EnumPair.first);
+
+  // If invoked we expect to have an enum
+  if (ConsumeExpectedToken(EnumToks))
+    return true;
+
+  // ConsumeExpectedToken only succeeds on one of EnumMap's keys, so this
+  // lookup cannot miss.
+  auto It = EnumMap.find(CurTok->Kind);
+  if (It == EnumMap.end())
+    llvm_unreachable("Switch for an expected token was not provided");
+
+  *Enum = It->second;
+  return false;
+}
+
+// Parse one SHADER_VISIBILITY_* enumeration token into *Enum.
+// Returns true on error.
+bool RootSignatureParser::ParseShaderVisibility(ShaderVisibility *Enum) {
+  // Every SHADER_VISIBILITY_ENUM entry in the token-kind definition file
+  // contributes one token-kind -> ShaderVisibility pair.
+  const llvm::SmallDenseMap<TokenKind, ShaderVisibility> VisibilityMap = {
+#define SHADER_VISIBILITY_ENUM(NAME, LIT)                                      \
+  {TokenKind::en_##NAME, ShaderVisibility::NAME},
+#include "clang/Parse/HLSLRootSignatureTokenKinds.def"
+  };
+  return ParseEnum(VisibilityMap, Enum);
+}
+
 RootSignatureToken RootSignatureParser::PeekNextToken() {
   // Create an invalid token
   RootSignatureToken Token = RootSignatureToken(SourceLocation());
diff --git a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
index 003f09f353338e..1ee00b0db7746a 100644
--- a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
+++ b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
@@ -308,6 +308,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseEmptyTest) {
 TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
   const llvm::StringLiteral Source = R"cc(
     DescriptorTable(
+      visibility = SHADER_VISIBILITY_PIXEL,
       CBV(b0),
       SRV(t42, space = 3, offset = 32, numDescriptors = 4),
       Sampler(s987, space = 2, offset = DESCRIPTOR_RANGE_OFFSET_APPEND),
@@ -382,11 +383,15 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
   Elem = Elements[4];
   ASSERT_TRUE(std::holds_alternative<DescriptorTable>(Elem));
   ASSERT_EQ(std::get<DescriptorTable>(Elem).NumClauses, (uint32_t)4);
+  ASSERT_EQ(std::get<DescriptorTable>(Elem).Visibility,
+            ShaderVisibility::Pixel);
 
   Elem = Elements[5];
   // Test generated DescriptorTable start has correct default values
   ASSERT_TRUE(std::holds_alternative<DescriptorTable>(Elem));
   ASSERT_EQ(std::get<DescriptorTable>(Elem).NumClauses, (uint32_t)0);
+  ASSERT_EQ(std::get<DescriptorTable>(Elem).Visibility, ShaderVisibility::All);
+
   ASSERT_TRUE(Consumer->IsSatisfied());
 }
 
@@ -467,4 +472,31 @@ TEST_F(ParseHLSLRootSignatureTest, InvalidParseRepeatedParamTest) {
   ASSERT_TRUE(Consumer->IsSatisfied());
 }
 
+TEST_F(ParseHLSLRootSignatureTest, InvalidParseRepeatedVisibilityTest) {
+  const llvm::StringLiteral Source = R"cc(
+    DescriptorTable(
+      visibility = SHADER_VISIBILITY_GEOMETRY,
+      visibility = SHADER_VISIBILITY_HULL
+    )
+  )cc";
+
+  TrivialModuleLoader ModLoader;
+  auto PP = CreatePP(Source, ModLoader);
+  auto TokLoc = SourceLocation();
+
+  // Test correct diagnostic produced
+  Consumer->SetExpected(diag::err_hlsl_rootsig_repeat_param);
+  hlsl::RootSignatureLexer Lexer(Source, TokLoc, *PP);
+
+  SmallVector<hlsl::RootSignatureToken> Tokens;
+  ASSERT_FALSE(Lexer.Lex(Tokens));
+
+  SmallVector<RootElement> Elements;
+  hlsl::RootSignatureParser Parser(Elements, Tokens, Diags);
+
+  ASSERT_TRUE(Parser.Parse());
+
+  ASSERT_TRUE(Consumer->IsSatisfied());
+}
+
 } // anonymous namespace
diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
index fc0fa160eebd94..6b1feca48801d8 100644
--- a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
+++ b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
@@ -25,6 +25,17 @@ namespace root_signature {
 
 enum class DescriptorRangeOffset : uint32_t;
 
+enum class ShaderVisibility {
+  All = 0,
+  Vertex = 1,
+  Hull = 2,
+  Domain = 3,
+  Geometry = 4,
+  Pixel = 5,
+  Amplification = 6,
+  Mesh = 7,
+};
+
 // Definitions of the in-memory data layout structures
 
 // Models the different registers: bReg | tReg | uReg | sReg
@@ -36,6 +47,7 @@ struct Register {
 
 // Models the end of a descriptor table and stores its visibility
 struct DescriptorTable {
+  ShaderVisibility Visibility = ShaderVisibility::All;
   uint32_t NumClauses = 0; // The number of clauses in the table
 };
 
@@ -56,7 +68,8 @@ using RootElement = std::variant<DescriptorTable, DescriptorTableClause>;
 
 // Models a reference to all assignment parameter types that any RootElement
 // may have. Things of the form: Keyword = Param
-using ParamType = std::variant<uint32_t *, DescriptorRangeOffset *>;
+using ParamType = std::variant<uint32_t *, DescriptorRangeOffset *,
+                               ShaderVisibility *>;
 } // namespace root_signature
 } // namespace hlsl
 } // namespace llvm

>From a8ccb65c58cb41ff242320a8b533e52f59cbc2d5 Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Tue, 28 Jan 2025 19:46:34 +0000
Subject: [PATCH 16/18] add support for parsing Flag parameters

- use DescriptorRangeFlags to demonstrate valid functionality
---
 .../clang/Parse/ParseHLSLRootSignature.h      |  6 ++-
 clang/lib/Parse/ParseHLSLRootSignature.cpp    | 50 ++++++++++++++++-
 .../Parse/ParseHLSLRootSignatureTest.cpp      | 16 +++++-
 .../llvm/Frontend/HLSL/HLSLRootSignature.h    | 54 ++++++++++++++++++-
 4 files changed, 121 insertions(+), 5 deletions(-)

diff --git a/clang/include/clang/Parse/ParseHLSLRootSignature.h b/clang/include/clang/Parse/ParseHLSLRootSignature.h
index 4fc04ca72bf08c..942b8bfdbc1083 100644
--- a/clang/include/clang/Parse/ParseHLSLRootSignature.h
+++ b/clang/include/clang/Parse/ParseHLSLRootSignature.h
@@ -125,9 +125,13 @@ class RootSignatureParser {
   bool ParseDescriptorRangeOffset(rs::DescriptorRangeOffset *X);
 
   // Various flags/enum parsing helpers
-  template <typename EnumType>
+  template <bool AllowZero = false, typename EnumType>
   bool ParseEnum(llvm::SmallDenseMap<TokenKind, EnumType> EnumMap,
                  EnumType *Enum);
+  template <typename FlagType>
+  bool ParseFlags(llvm::SmallDenseMap<TokenKind, FlagType> EnumMap,
+                  FlagType *Enum);
+  bool ParseDescriptorRangeFlags(rs::DescriptorRangeFlags *Enum);
   bool ParseShaderVisibility(rs::ShaderVisibility *Enum);
 
   // Increment the token iterator if we have not reached the end.
diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp
index e758c4b95abbb4..74230b43de3650 100644
--- a/clang/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -330,6 +330,8 @@ bool RootSignatureParser::ParseDescriptorTableClause() {
     llvm_unreachable("Switch for an expected token was not provided");
     return true;
   }
+  Clause.SetDefaultFlags();
+
   if (ConsumeExpectedToken(TokenKind::pu_l_paren))
     return true;
 
@@ -345,6 +347,7 @@ bool RootSignatureParser::ParseDescriptorTableClause() {
       {TokenKind::kw_numDescriptors, &Clause.NumDescriptors},
       {TokenKind::kw_space, &Clause.Space},
       {TokenKind::kw_offset, &Clause.Offset},
+      {TokenKind::kw_flags, &Clause.Flags},
   };
   if (ParseOptionalParams({RefMap}))
     return true;
@@ -371,6 +374,9 @@ bool RootSignatureParser::ParseParam(ParamType Ref) {
                                [&](DescriptorRangeOffset *X) {
                                  Error = ParseDescriptorRangeOffset(X);
                                },
+                               [&](DescriptorRangeFlags *Flags) {
+                                 Error = ParseDescriptorRangeFlags(Flags);
+                               },
                                [&](ShaderVisibility *Enum) {
                                  Error = ParseShaderVisibility(Enum);
                                },
@@ -454,10 +460,12 @@ bool RootSignatureParser::ParseRegister(Register *Register) {
   return false;
 }
 
-template <typename EnumType>
+template <bool AllowZero, typename EnumType>
 bool RootSignatureParser::ParseEnum(
     llvm::SmallDenseMap<TokenKind, EnumType> EnumMap, EnumType *Enum) {
   SmallVector<TokenKind> EnumToks;
+  if (AllowZero)
+    EnumToks.push_back(TokenKind::int_literal); //  '0' is a valid flag value
   for (auto EnumPair : EnumMap)
     EnumToks.push_back(EnumPair.first);
 
@@ -465,6 +473,16 @@ bool RootSignatureParser::ParseEnum(
   if (ConsumeExpectedToken(EnumToks))
     return true;
 
+  // Handle the edge case when '0' is used to specify None
+  if (CurTok->Kind == TokenKind::int_literal) {
+    if (CurTok->NumLiteral.getInt() != 0) {
+      return true;
+    }
+    // Set enum to None equivalent
+    *Enum = EnumType(0);
+    return false;
+  }
+
   // Effectively a switch statement on the token kinds
   for (auto EnumPair : EnumMap)
     if (CurTok->Kind == EnumPair.first) {
@@ -476,6 +494,36 @@ bool RootSignatureParser::ParseEnum(
   return true;
 }
 
+template <typename FlagType>
+bool RootSignatureParser::ParseFlags(
+    llvm::SmallDenseMap<TokenKind, FlagType> FlagMap, FlagType *Flags) {
+  // Override the default value to 0 so that we can correctly 'or' the values
+  *Flags = FlagType(0);
+
+  do {
+    FlagType Flag;
+    if (ParseEnum<true>(FlagMap, &Flag))
+      return true;
+    // Store the 'or'
+    *Flags |= Flag;
+
+  } while (!TryConsumeExpectedToken(TokenKind::pu_or));
+
+  return false;
+}
+
+bool RootSignatureParser::ParseDescriptorRangeFlags(
+    DescriptorRangeFlags *Flags) {
+  // Define the possible flag kinds
+  llvm::SmallDenseMap<TokenKind, DescriptorRangeFlags> FlagMap = {
+#define DESCRIPTOR_RANGE_FLAG_ENUM(NAME, LIT, ON)                              \
+  {TokenKind::en_##NAME, DescriptorRangeFlags::NAME},
+#include "clang/Parse/HLSLRootSignatureTokenKinds.def"
+  };
+
+  return ParseFlags(FlagMap, Flags);
+}
+
 bool RootSignatureParser::ParseShaderVisibility(ShaderVisibility *Enum) {
   // Define the possible flag kinds
   llvm::SmallDenseMap<TokenKind, ShaderVisibility> EnumMap = {
diff --git a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
index 1ee00b0db7746a..dd502ce6444f39 100644
--- a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
+++ b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
@@ -310,9 +310,13 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
     DescriptorTable(
       visibility = SHADER_VISIBILITY_PIXEL,
       CBV(b0),
-      SRV(t42, space = 3, offset = 32, numDescriptors = 4),
+      SRV(t42, space = 3, offset = 32, numDescriptors = 4, flags = 0),
       Sampler(s987, space = 2, offset = DESCRIPTOR_RANGE_OFFSET_APPEND),
-      UAV(u987234)
+      UAV(u987234,
+        flags = Descriptors_Volatile | Data_Volatile
+                      | Data_Static_While_Set_At_Execute | Data_Static
+                      | Descriptors_Static_Keeping_Buffer_Bounds_Checks
+      )
     ),
     DescriptorTable()
   )cc";
@@ -343,6 +347,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, (uint32_t)0);
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Offset,
             DescriptorRangeOffset(DescriptorTableOffsetAppend));
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags,
+            DescriptorRangeFlags::DataStaticWhileSetAtExecute);
 
   Elem = Elements[1];
   ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
@@ -355,6 +361,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, (uint32_t)3);
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Offset,
             DescriptorRangeOffset(32));
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags,
+            DescriptorRangeFlags::None);
 
   Elem = Elements[2];
   ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
@@ -367,6 +375,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, (uint32_t)2);
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Offset,
             DescriptorRangeOffset(DescriptorTableOffsetAppend));
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags,
+            DescriptorRangeFlags::None);
 
   Elem = Elements[3];
   ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
@@ -379,6 +389,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, (uint32_t)0);
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Offset,
             DescriptorRangeOffset(DescriptorTableOffsetAppend));
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags,
+            DescriptorRangeFlags::ValidFlags);
 
   Elem = Elements[4];
   ASSERT_TRUE(std::holds_alternative<DescriptorTable>(Elem));
diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
index 6b1feca48801d8..511ffd96a6d16c 100644
--- a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
+++ b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
@@ -21,10 +21,43 @@ namespace llvm {
 namespace hlsl {
 namespace root_signature {
 
+#define RS_DEFINE_ENUM_CLASS_FLAGS_OPERATORS(Class)                            \
+  inline Class operator|(Class a, Class b) {                                   \
+    return static_cast<Class>(llvm::to_underlying(a) |                         \
+                              llvm::to_underlying(b));                         \
+  }                                                                            \
+  inline Class operator&(Class a, Class b) {                                   \
+    return static_cast<Class>(llvm::to_underlying(a) &                         \
+                              llvm::to_underlying(b));                         \
+  }                                                                            \
+  inline Class operator~(Class a) {                                            \
+    return static_cast<Class>(~llvm::to_underlying(a));                        \
+  }                                                                            \
+  inline Class &operator|=(Class &a, Class b) {                                \
+    a = a | b;                                                                 \
+    return a;                                                                  \
+  }                                                                            \
+  inline Class &operator&=(Class &a, Class b) {                                \
+    a = a & b;                                                                 \
+    return a;                                                                  \
+  }
+
 // Definition of the various enumerations and flags
 
 enum class DescriptorRangeOffset : uint32_t;
 
+enum class DescriptorRangeFlags : unsigned {
+  None = 0,
+  DescriptorsVolatile = 0x1,
+  DataVolatile = 0x2,
+  DataStaticWhileSetAtExecute = 0x4,
+  DataStatic = 0x8,
+  DescriptorsStaticKeepingBufferBoundsChecks = 0x10000,
+  ValidFlags = 0x1000f,
+  ValidSamplerFlags = DescriptorsVolatile,
+};
+RS_DEFINE_ENUM_CLASS_FLAGS_OPERATORS(DescriptorRangeFlags)
+
 enum class ShaderVisibility {
   All = 0,
   Vertex = 1,
@@ -61,6 +94,24 @@ struct DescriptorTableClause {
   uint32_t NumDescriptors = 1;
   uint32_t Space = 0;
   DescriptorRangeOffset Offset = DescriptorTableOffsetAppend;
+  DescriptorRangeFlags Flags;
+
+  void SetDefaultFlags() {
+    switch (Type) {
+    case ClauseType::CBuffer:
+      Flags = DescriptorRangeFlags::DataStaticWhileSetAtExecute;
+      break;
+    case ClauseType::SRV:
+      Flags = DescriptorRangeFlags::DataStaticWhileSetAtExecute;
+      break;
+    case ClauseType::UAV:
+      Flags = DescriptorRangeFlags::DataVolatile;
+      break;
+    case ClauseType::Sampler:
+      Flags = DescriptorRangeFlags::None;
+      break;
+    }
+  }
 };
 
 // Models RootElement : DescriptorTable | DescriptorTableClause
@@ -69,7 +120,8 @@ using RootElement = std::variant<DescriptorTable, DescriptorTableClause>;
 // Models a reference to all assignment parameter types that any RootElement
 // may have. Things of the form: Keyword = Param
 using ParamType = std::variant<uint32_t *, DescriptorRangeOffset *,
-                               ShaderVisibility *>;
+                               DescriptorRangeFlags *, ShaderVisibility *>;
+
 } // namespace root_signature
 } // namespace hlsl
 } // namespace llvm

>From 91989a014fcffa4ace5c3bb8d2e72c28fab236fe Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Tue, 28 Jan 2025 19:47:44 +0000
Subject: [PATCH 17/18] add diagnostic for non-zero int literal as flag

---
 .../clang/Basic/DiagnosticParseKinds.td       |  2 ++
 clang/lib/Parse/ParseHLSLRootSignature.cpp    |  1 +
 .../Parse/ParseHLSLRootSignatureTest.cpp      | 26 +++++++++++++++++++
 3 files changed, 29 insertions(+)

diff --git a/clang/include/clang/Basic/DiagnosticParseKinds.td b/clang/include/clang/Basic/DiagnosticParseKinds.td
index 4994d2339ce578..5cb5b3b404c7a1 100644
--- a/clang/include/clang/Basic/DiagnosticParseKinds.td
+++ b/clang/include/clang/Basic/DiagnosticParseKinds.td
@@ -1810,4 +1810,6 @@ def err_hlsl_packoffset_invalid_reg : Error<"invalid resource class specifier '%
 def err_hlsl_rootsig_unexpected_eos : Error<"unexpected end to token stream">;
 def err_hlsl_rootsig_unexpected_token_kind : Error<"expected the %select{following|one of the following}0 token kinds '%1'">;
 def err_hlsl_rootsig_repeat_param : Error<"specified the same parameter '%0' multiple times">;
+def err_hlsl_rootsig_non_zero_flag : Error<"specified a non-zero integer as a flag">;
+
 } // end of Parser diagnostics
diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp
index 74230b43de3650..8611483e1c4fe0 100644
--- a/clang/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -476,6 +476,7 @@ bool RootSignatureParser::ParseEnum(
   // Handle the edge case when '0' is used to specify None
   if (CurTok->Kind == TokenKind::int_literal) {
     if (CurTok->NumLiteral.getInt() != 0) {
+      Diags.Report(CurTok->TokLoc, diag::err_hlsl_rootsig_non_zero_flag);
       return true;
     }
     // Set enum to None equivalent
diff --git a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
index dd502ce6444f39..1375edfe3ebfe5 100644
--- a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
+++ b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
@@ -511,4 +511,30 @@ TEST_F(ParseHLSLRootSignatureTest, InvalidParseRepeatedVisibilityTest) {
   ASSERT_TRUE(Consumer->IsSatisfied());
 }
 
+TEST_F(ParseHLSLRootSignatureTest, InvalidParseNonZeroFlagTest) {
+  const llvm::StringLiteral Source = R"cc(
+    DescriptorTable(
+      CBV(b0, flags = 3)
+    )
+  )cc";
+
+  TrivialModuleLoader ModLoader;
+  auto PP = CreatePP(Source, ModLoader);
+  auto TokLoc = SourceLocation();
+
+  // Test correct diagnostic produced
+  Consumer->SetExpected(diag::err_hlsl_rootsig_non_zero_flag);
+  hlsl::RootSignatureLexer Lexer(Source, TokLoc, *PP);
+
+  SmallVector<hlsl::RootSignatureToken> Tokens;
+  ASSERT_FALSE(Lexer.Lex(Tokens));
+
+  SmallVector<RootElement> Elements;
+  hlsl::RootSignatureParser Parser(Elements, Tokens, Diags);
+
+  ASSERT_TRUE(Parser.Parse());
+
+  ASSERT_TRUE(Consumer->IsSatisfied());
+}
+
 } // anonymous namespace

>From 65ebfe19f44075ff3d3c4a7fdaae753d45f0c60f Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Tue, 28 Jan 2025 19:50:45 +0000
Subject: [PATCH 18/18] visit clang-format

---
 clang/lib/Parse/ParseHLSLRootSignature.cpp | 23 +++++++++++-----------
 1 file changed, 12 insertions(+), 11 deletions(-)

diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp
index 8611483e1c4fe0..5be289906a0c9f 100644
--- a/clang/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -370,17 +370,18 @@ bool RootSignatureParser::ParseParam(ParamType Ref) {
     return true;
 
   bool Error;
-  std::visit(OverloadedMethods{[&](uint32_t *X) { Error = ParseUInt(X); },
-                               [&](DescriptorRangeOffset *X) {
-                                 Error = ParseDescriptorRangeOffset(X);
-                               },
-                               [&](DescriptorRangeFlags *Flags) {
-                                 Error = ParseDescriptorRangeFlags(Flags);
-                               },
-                               [&](ShaderVisibility *Enum) {
-                                 Error = ParseShaderVisibility(Enum);
-                               },
-  }, Ref);
+  std::visit(
+      OverloadedMethods{
+          [&](uint32_t *X) { Error = ParseUInt(X); },
+          [&](DescriptorRangeOffset *X) {
+            Error = ParseDescriptorRangeOffset(X);
+          },
+          [&](DescriptorRangeFlags *Flags) {
+            Error = ParseDescriptorRangeFlags(Flags);
+          },
+          [&](ShaderVisibility *Enum) { Error = ParseShaderVisibility(Enum); },
+      },
+      Ref);
 
   return Error;
 }



More information about the llvm-branch-commits mailing list