[clang] fecf074 - [HLSL][RootSignature] Add parsing of DescriptorRangeFlags (#136775)

via cfe-commits cfe-commits at lists.llvm.org
Fri Apr 25 13:05:33 PDT 2025


Author: Finn Plummer
Date: 2025-04-25T13:05:30-07:00
New Revision: fecf0742b16dc332c7a75b0a6696f08694943862

URL: https://github.com/llvm/llvm-project/commit/fecf0742b16dc332c7a75b0a6696f08694943862
DIFF: https://github.com/llvm/llvm-project/commit/fecf0742b16dc332c7a75b0a6696f08694943862.diff

LOG: [HLSL][RootSignature] Add parsing of DescriptorRangeFlags (#136775)

- Defines `parseDescriptorRangeFlags` to establish a pattern of how
flags will be parsed
- Add corresponding unit tests

Part four of implementing #126569

Added: 
    

Modified: 
    clang/include/clang/Parse/ParseHLSLRootSignature.h
    clang/lib/Parse/ParseHLSLRootSignature.cpp
    clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
    llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h

Removed: 
    


################################################################################
diff  --git a/clang/include/clang/Parse/ParseHLSLRootSignature.h b/clang/include/clang/Parse/ParseHLSLRootSignature.h
index d639ca91c002f..d2e8f4dbcfc0c 100644
--- a/clang/include/clang/Parse/ParseHLSLRootSignature.h
+++ b/clang/include/clang/Parse/ParseHLSLRootSignature.h
@@ -81,6 +81,7 @@ class RootSignatureParser {
   struct ParsedClauseParams {
     std::optional<llvm::hlsl::rootsig::Register> Reg;
     std::optional<uint32_t> Space;
+    std::optional<llvm::hlsl::rootsig::DescriptorRangeFlags> Flags;
   };
   std::optional<ParsedClauseParams>
   parseDescriptorTableClauseParams(RootSignatureToken::Kind RegType);
@@ -91,11 +92,19 @@ class RootSignatureParser {
 
   /// Parsing methods of various enums
   std::optional<llvm::hlsl::rootsig::ShaderVisibility> parseShaderVisibility();
+  std::optional<llvm::hlsl::rootsig::DescriptorRangeFlags>
+  parseDescriptorRangeFlags();
 
   /// Use NumericLiteralParser to convert CurToken.NumSpelling into a unsigned
   /// 32-bit integer
   std::optional<uint32_t> handleUIntLiteral();
 
+  /// Flags may specify the value of '0' to denote that there should be no
+  /// flags set.
+  ///
+  /// Return true if the current int_literal token is '0', otherwise false
+  bool verifyZeroFlag();
+
   /// Invoke the Lexer to consume a token and update CurToken with the result
   void consumeNextToken() { CurToken = Lexer.consumeToken(); }
 

diff  --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp
index 8244e91c8f89a..3b9e96017c88d 100644
--- a/clang/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -129,6 +129,7 @@ RootSignatureParser::parseDescriptorTableClause() {
     ExpectedReg = TokenKind::sReg;
     break;
   }
+  Clause.setDefaultFlags();
 
   auto Params = parseDescriptorTableClauseParams(ExpectedReg);
   if (!Params.has_value())
@@ -147,6 +148,9 @@ RootSignatureParser::parseDescriptorTableClause() {
   if (Params->Space.has_value())
     Clause.Space = Params->Space.value();
 
+  if (Params->Flags.has_value())
+    Clause.Flags = Params->Flags.value();
+
   if (consumeExpectedToken(TokenKind::pu_r_paren,
                            diag::err_hlsl_unexpected_end_of_params,
                            /*param of=*/ParamKind))
@@ -194,6 +198,24 @@ RootSignatureParser::parseDescriptorTableClauseParams(TokenKind RegType) {
         return std::nullopt;
       Params.Space = Space;
     }
+
+    // `flags` `=` DESCRIPTOR_RANGE_FLAGS
+    if (tryConsumeExpectedToken(TokenKind::kw_flags)) {
+      if (Params.Flags.has_value()) {
+        getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_repeat_param)
+            << CurToken.TokKind;
+        return std::nullopt;
+      }
+
+      if (consumeExpectedToken(TokenKind::pu_equal))
+        return std::nullopt;
+
+      auto Flags = parseDescriptorRangeFlags();
+      if (!Flags.has_value())
+        return std::nullopt;
+      Params.Flags = Flags;
+    }
+
   } while (tryConsumeExpectedToken(TokenKind::pu_comma));
 
   return Params;
@@ -268,6 +290,54 @@ RootSignatureParser::parseShaderVisibility() {
   return std::nullopt;
 }
 
+/// Return `Flags | Flag`, or just `Flag` when `Flags` is empty.
+template <typename FlagType>
+static FlagType maybeOrFlag(std::optional<FlagType> Flags, FlagType Flag) {
+  if (Flags)
+    return static_cast<FlagType>(llvm::to_underlying(*Flags) |
+                                 llvm::to_underlying(Flag));
+  return Flag;
+}
+
+std::optional<llvm::hlsl::rootsig::DescriptorRangeFlags>
+RootSignatureParser::parseDescriptorRangeFlags() {
+  assert(CurToken.TokKind == TokenKind::pu_equal &&
+         "Expects to only be invoked starting at given keyword");
+
+  // Handle the edge-case of '0' to specify no flags set
+  if (tryConsumeExpectedToken(TokenKind::int_literal)) {
+    if (!verifyZeroFlag()) {
+      getDiags().Report(CurToken.TokLoc, diag::err_expected) << "'0'";
+      return std::nullopt;
+    }
+    return DescriptorRangeFlags::None;
+  }
+
+  TokenKind Expected[] = {
+#define DESCRIPTOR_RANGE_FLAG_ENUM(NAME, LIT, ON) TokenKind::en_##NAME,
+#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
+  };
+
+  std::optional<DescriptorRangeFlags> Flags;
+
+  do {
+    if (tryConsumeExpectedToken(Expected)) {
+      switch (CurToken.TokKind) {
+#define DESCRIPTOR_RANGE_FLAG_ENUM(NAME, LIT, ON)                              \
+  case TokenKind::en_##NAME:                                                   \
+    Flags =                                                                    \
+        maybeOrFlag<DescriptorRangeFlags>(Flags, DescriptorRangeFlags::NAME);  \
+    break;
+#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
+      default:
+        llvm_unreachable("Switch for consumed enum token was not provided");
+      }
+    }
+  } while (tryConsumeExpectedToken(TokenKind::pu_or));
+
+  return Flags;
+}
+
 std::optional<uint32_t> RootSignatureParser::handleUIntLiteral() {
   // Parse the numeric value and do semantic checks on its specification
   clang::NumericLiteralParser Literal(CurToken.NumSpelling, CurToken.TokLoc,
@@ -290,6 +360,12 @@ std::optional<uint32_t> RootSignatureParser::handleUIntLiteral() {
   return Val.getExtValue();
 }
 
+bool RootSignatureParser::verifyZeroFlag() {
+  assert(CurToken.TokKind == TokenKind::int_literal);
+  // A literal that fails to convert (std::nullopt) never compares equal to 0.
+  return handleUIntLiteral() == 0u;
+}
+
 bool RootSignatureParser::peekExpectedToken(TokenKind Expected) {
   return peekExpectedToken(ArrayRef{Expected});
 }

diff  --git a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
index 1d89567509e72..f4baf1580de61 100644
--- a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
+++ b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
@@ -130,10 +130,14 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
   const llvm::StringLiteral Source = R"cc(
     DescriptorTable(
       CBV(b0),
-      SRV(space = 3, t42),
+      SRV(space = 3, t42, flags = 0),
       visibility = SHADER_VISIBILITY_PIXEL,
       Sampler(s987, space = +2),
-      UAV(u4294967294)
+      UAV(u4294967294,
+        flags = Descriptors_Volatile | Data_Volatile
+                      | Data_Static_While_Set_At_Execute | Data_Static
+                      | Descriptors_Static_Keeping_Buffer_Bounds_Checks
+      )
     ),
     DescriptorTable()
   )cc";
@@ -159,6 +163,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
             RegisterType::BReg);
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Reg.Number, 0u);
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 0u);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags,
+            DescriptorRangeFlags::DataStaticWhileSetAtExecute);
 
   Elem = Elements[1];
   ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
@@ -167,6 +173,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
             RegisterType::TReg);
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Reg.Number, 42u);
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 3u);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags,
+            DescriptorRangeFlags::None);
 
   Elem = Elements[2];
   ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
@@ -175,6 +183,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
             RegisterType::SReg);
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Reg.Number, 987u);
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 2u);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags,
+            DescriptorRangeFlags::None);
 
   Elem = Elements[3];
   ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
@@ -183,6 +193,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
             RegisterType::UReg);
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Reg.Number, 4294967294u);
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 0u);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags,
+            DescriptorRangeFlags::ValidFlags);
 
   Elem = Elements[4];
   ASSERT_TRUE(std::holds_alternative<DescriptorTable>(Elem));
@@ -199,6 +211,35 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
   ASSERT_TRUE(Consumer->isSatisfied());
 }
 
+TEST_F(ParseHLSLRootSignatureTest, ValidSamplerFlagsTest) {
+  // Checks that DESCRIPTORS_VOLATILE, the one flag permitted on a Sampler
+  // range, can be set on a Sampler clause
+  const llvm::StringLiteral Source = R"cc(
+    DescriptorTable(Sampler(s0, flags = DESCRIPTORS_VOLATILE))
+  )cc";
+
+  TrivialModuleLoader ModLoader;
+  auto PP = createPP(Source, ModLoader);
+  auto TokLoc = SourceLocation();
+
+  hlsl::RootSignatureLexer Lexer(Source, TokLoc);
+  SmallVector<RootElement> Elements;
+  hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);
+
+  // A valid sampler flag must not emit any diagnostic
+  Consumer->setNoDiag();
+
+  ASSERT_FALSE(Parser.parse());
+
+  RootElement Elem = Elements[0];
+  ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
+  const auto &Clause = std::get<DescriptorTableClause>(Elem);
+  ASSERT_EQ(Clause.Type, ClauseType::Sampler);
+  ASSERT_EQ(Clause.Flags, DescriptorRangeFlags::ValidSamplerFlags);
+
+  ASSERT_TRUE(Consumer->isSatisfied());
+}
+
 TEST_F(ParseHLSLRootSignatureTest, ValidTrailingCommaTest) {
   // This test will checks we can handling trailing commas ','
   const llvm::StringLiteral Source = R"cc(
@@ -383,4 +424,28 @@ TEST_F(ParseHLSLRootSignatureTest, InvalidLexOverflowedNumberTest) {
   ASSERT_TRUE(Consumer->isSatisfied());
 }
 
+TEST_F(ParseHLSLRootSignatureTest, InvalidNonZeroFlagsTest) {
+  // Checks that parsing fails when `flags` is given an integer literal other
+  // than '0'
+  const llvm::StringLiteral Source = R"cc(
+    DescriptorTable(
+      CBV(b0, flags = 3)
+    )
+  )cc";
+
+  TrivialModuleLoader ModLoader;
+  auto PP = createPP(Source, ModLoader);
+  auto TokLoc = SourceLocation();
+
+  hlsl::RootSignatureLexer Lexer(Source, TokLoc);
+  SmallVector<RootElement> Elements;
+  hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);
+
+  // Expect an err_expected diagnostic and a failing parse (parse() == true)
+  Consumer->setExpected(diag::err_expected);
+  const bool ParseFailed = Parser.parse();
+  ASSERT_TRUE(ParseFailed);
+  ASSERT_TRUE(Consumer->isSatisfied());
+}
+
 } // anonymous namespace

diff  --git a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
index d51b853942dd3..0745bce983bb3 100644
--- a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
+++ b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
@@ -23,6 +23,17 @@ namespace rootsig {
 
 // Definition of the various enumerations and flags
 
+enum class DescriptorRangeFlags : unsigned {
+  None = 0x00000,
+  DescriptorsVolatile = 0x00001,
+  DataVolatile = 0x00002,
+  DataStaticWhileSetAtExecute = 0x00004,
+  DataStatic = 0x00008,
+  DescriptorsStaticKeepingBufferBoundsChecks = 0x10000,
+  ValidFlags = 0x1000F,
+  ValidSamplerFlags = DescriptorsVolatile,
+};
+
 enum class ShaderVisibility {
   All = 0,
   Vertex = 1,
@@ -55,6 +66,22 @@ struct DescriptorTableClause {
   ClauseType Type;
   Register Reg;
   uint32_t Space = 0;
+  DescriptorRangeFlags Flags;
+
+  // Assign Flags the default for this clause's Type (used if `flags` absent).
+  void setDefaultFlags() {
+    switch (Type) {
+    case ClauseType::Sampler:
+      Flags = DescriptorRangeFlags::None;
+      return;
+    case ClauseType::UAV:
+      Flags = DescriptorRangeFlags::DataVolatile;
+      return;
+    case ClauseType::CBuffer:
+    case ClauseType::SRV:
+      Flags = DescriptorRangeFlags::DataStaticWhileSetAtExecute;
+    }
+  }
 };
 
 // Models RootElement : DescriptorTable | DescriptorTableClause


        


More information about the cfe-commits mailing list