[clang] [llvm] [HLSL][RootSignature] Implement parsing of `RootParameter`s (PR #121803)
Finn Plummer via llvm-commits
llvm-commits at lists.llvm.org
Mon Jan 6 09:16:27 PST 2025
https://github.com/inbelic created https://github.com/llvm/llvm-project/pull/121803
```
- Implement the ParseRootParameter methods in ParseHLSLRootSignature
- Define the in-memory representation of the various RootParameters and add
them to the RootElement structure
- Add testing of valid inputs to ParseHLSLRootSignatureTest.cpp
```
Part of the work for #120472
>From bee90e659e647df21d1f1e65f95fffefa57eadce Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Mon, 6 Jan 2025 16:21:33 +0000
Subject: [PATCH 1/2] [HLSL] Implement parsing of `RootFlags`
- Define the Parser class that will contain all the parsing methods in
ParseHLSLRootSignature.h
- Implement the dispatch behaviour of Parse and ParseRootElement in
ParseHLSLRootSignature.cpp
- Define the general in-memory data structure of a RootElement that will
be a union of the various RootElement types
- Implement the ParseRootFlags methods in ParseHLSLRootSignature
- Define the in-memory representation of the RootFlags and add it to the
RootElement structure
- Add testing of valid inputs to ParseHLSLRootSignatureTest.cpp
---
.../clang/Sema/ParseHLSLRootSignature.h | 63 ++++++++
clang/lib/Sema/CMakeLists.txt | 1 +
clang/lib/Sema/ParseHLSLRootSignature.cpp | 139 ++++++++++++++++++
clang/unittests/Sema/CMakeLists.txt | 1 +
.../Sema/ParseHLSLRootSignatureTest.cpp | 58 ++++++++
.../llvm/Frontend/HLSL/HLSLRootSignature.h | 88 +++++++++++
6 files changed, 350 insertions(+)
create mode 100644 clang/include/clang/Sema/ParseHLSLRootSignature.h
create mode 100644 clang/lib/Sema/ParseHLSLRootSignature.cpp
create mode 100644 clang/unittests/Sema/ParseHLSLRootSignatureTest.cpp
create mode 100644 llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
diff --git a/clang/include/clang/Sema/ParseHLSLRootSignature.h b/clang/include/clang/Sema/ParseHLSLRootSignature.h
new file mode 100644
index 00000000000000..7d1799e22b515c
--- /dev/null
+++ b/clang/include/clang/Sema/ParseHLSLRootSignature.h
@@ -0,0 +1,63 @@
+//===--- ParseHLSLRootSignature.h -------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the ParseHLSLRootSignature interface.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_SEMA_PARSEHLSLROOTSIGNATURE_H
+#define LLVM_CLANG_SEMA_PARSEHLSLROOTSIGNATURE_H
+
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/StringSwitch.h"
+
+#include "llvm/Frontend/HLSL/HLSLRootSignature.h"
+
+namespace llvm {
+namespace hlsl {
+namespace root_signature {
+
+/// Parser for the textual form of an HLSL root signature.
+///
+/// Every Parse* method consumes what it recognises from the front of Buffer
+/// and returns true on failure, false on success.
+class Parser {
+public:
+  /// \param Source root signature text to parse.
+  /// \param Out    vector that parsed root elements are appended to; the
+  ///               pointer is stored (not owned) and used by Parse().
+  Parser(StringRef Source, SmallVector<RootElement> *Out)
+      : Buffer(Source), Elements(Out) {}
+
+  /// Consumes the internal buffer as a list of root elements, appending the
+  /// in-memory representation of each one to Elements.
+  ///
+  /// Stops at the end of the buffer or at the first error.
+  /// \returns true if a failure occurred, false otherwise.
+  bool Parse();
+
+private:
+  /// Error hook used on every failure path; its result is handed back to the
+  /// caller as the failure value.
+  bool ReportError();
+
+  /// Dispatches on the element keyword currently held in Token.
+  bool ParseRootElement();
+  /// Parses the body of a RootFlags(...) element.
+  bool ParseRootFlags();
+
+  /// Consumes one keyword and maps it, ignoring case, through Mapping.
+  template <typename EnumType>
+  bool ParseEnum(SmallVector<std::pair<StringLiteral, EnumType>> Mapping,
+                 EnumType &Enum);
+  /// Parses a single RootFlags keyword.
+  bool ParseRootFlag(RootFlags &Flag);
+
+  /// Text that has not been consumed yet.
+  StringRef Buffer;
+  /// Destination for parsed elements (not owned).
+  SmallVector<RootElement> *Elements;
+  /// Most recently extracted token (element keyword or enum keyword).
+  StringRef Token;
+};
+
+} // namespace root_signature
+} // namespace hlsl
+} // namespace llvm
+
+#endif // LLVM_CLANG_SEMA_PARSEHLSLROOTSIGNATURE_H
diff --git a/clang/lib/Sema/CMakeLists.txt b/clang/lib/Sema/CMakeLists.txt
index 719c3a9312ec15..7141bb42eb4363 100644
--- a/clang/lib/Sema/CMakeLists.txt
+++ b/clang/lib/Sema/CMakeLists.txt
@@ -24,6 +24,7 @@ add_clang_library(clangSema
JumpDiagnostics.cpp
MultiplexExternalSemaSource.cpp
ParsedAttr.cpp
+ ParseHLSLRootSignature.cpp
Scope.cpp
ScopeInfo.cpp
Sema.cpp
diff --git a/clang/lib/Sema/ParseHLSLRootSignature.cpp b/clang/lib/Sema/ParseHLSLRootSignature.cpp
new file mode 100644
index 00000000000000..e4592ea1937178
--- /dev/null
+++ b/clang/lib/Sema/ParseHLSLRootSignature.cpp
@@ -0,0 +1,139 @@
+#include "clang/Sema/ParseHLSLRootSignature.h"
+
+namespace llvm {
+namespace hlsl {
+namespace root_signature {
+
+// TODO: Hook up with Sema to properly report semantic/validation errors
+bool Parser::ReportError() { return true; }
+
+// Parses the '|'-separated flag list of a RootFlags(...) element up to (but
+// not including) the closing ')', and appends the OR of all flags to
+// Elements. An empty list such as "RootFlags( )" yields RootFlags::None.
+bool Parser::ParseRootFlags() {
+  RootFlags Accumulated = RootFlags::None;
+  Buffer = Buffer.drop_while(isspace);
+
+  // The first flag has no separator; every later one must follow a '|'.
+  bool First = true;
+  while (!Buffer.starts_with(")")) {
+    if (!First && !Buffer.consume_front("|"))
+      return ReportError();
+    First = false;
+
+    Buffer = Buffer.drop_while(isspace);
+
+    RootFlags Parsed;
+    if (ParseRootFlag(Parsed))
+      return ReportError();
+    Accumulated |= Parsed;
+
+    Buffer = Buffer.drop_while(isspace);
+  }
+
+  // Record the completed element; Parse() consumes the ')'.
+  Elements->push_back(RootElement(Accumulated));
+  return false;
+}
+
+// Consumes an identifier-like token ([A-Za-z0-9_]*) from the front of Buffer,
+// storing it in Token, and looks it up case-insensitively in Mapping.
+// On a match the enumerator is written to Enum and false is returned;
+// otherwise Enum is left untouched and true is returned.
+template <typename EnumType>
+bool Parser::ParseEnum(SmallVector<std::pair<StringLiteral, EnumType>> Mapping,
+                       EnumType &Enum) {
+  Token = Buffer.take_while([](char C) { return isalnum(C) || C == '_'; });
+  Buffer = Buffer.drop_front(Token.size());
+
+  llvm::StringSwitch<std::optional<EnumType>> Lookup(Token);
+  for (const auto &[Keyword, Value] : Mapping)
+    Lookup.CaseLower(Keyword, Value);
+
+  std::optional<EnumType> Found = Lookup.Default(std::nullopt);
+  if (!Found.has_value())
+    return true;
+
+  Enum = *Found;
+  return false;
+}
+
+// Parses one RootFlags keyword (matched case-insensitively) into Flag.
+// Returns true if the next token is not a known root flag.
+bool Parser::ParseRootFlag(RootFlags &Flag) {
+  return ParseEnum<RootFlags>(
+      {
+          {"0", RootFlags::None},
+          {"ALLOW_INPUT_ASSEMBLER_INPUT_LAYOUT",
+           RootFlags::AllowInputAssemblerInputLayout},
+          {"DENY_VERTEX_SHADER_ROOT_ACCESS",
+           RootFlags::DenyVertexShaderRootAccess},
+          {"DENY_HULL_SHADER_ROOT_ACCESS", RootFlags::DenyHullShaderRootAccess},
+          {"DENY_DOMAIN_SHADER_ROOT_ACCESS",
+           RootFlags::DenyDomainShaderRootAccess},
+          {"DENY_GEOMETRY_SHADER_ROOT_ACCESS",
+           RootFlags::DenyGeometryShaderRootAccess},
+          {"DENY_PIXEL_SHADER_ROOT_ACCESS",
+           RootFlags::DenyPixelShaderRootAccess},
+          {"ALLOW_STREAM_OUTPUT", RootFlags::AllowStreamOutput},
+          {"LOCAL_ROOT_SIGNATURE", RootFlags::LocalRootSignature},
+          {"DENY_AMPLIFICATION_SHADER_ROOT_ACCESS",
+           RootFlags::DenyAmplificationShaderRootAccess},
+          {"DENY_MESH_SHADER_ROOT_ACCESS", RootFlags::DenyMeshShaderRootAccess},
+          {"CBV_SRV_UAV_HEAP_DIRECTLY_INDEXED",
+           RootFlags::CBVSRVUAVHeapDirectlyIndexed},
+          {"SAMPLER_HEAP_DIRECTLY_INDEXED",
+           RootFlags::SamplerHeapDirectlyIndexed},
+          {"AllowLowTierReservedHwCbLimit",
+           RootFlags::AllowLowTierReservedHwCbLimit},
+      },
+      Flag);
+}
+
+// Dispatches on the root element keyword held in Token to the matching
+// element parser. Unknown keywords are reported as errors.
+bool Parser::ParseRootElement() {
+  // Define different ParserMethods to use StringSwitch for dispatch
+  enum class ParserMethod {
+    ReportError,
+    ParseRootFlags,
+  };
+
+  // Retrieve which method should be used
+  auto Method = llvm::StringSwitch<ParserMethod>(Token)
+                    .Case("RootFlags", ParserMethod::ParseRootFlags)
+                    .Default(ParserMethod::ReportError);
+
+  // Dispatch on the correct method. The error case breaks out of the switch
+  // so every path through this non-void function ends in a return statement
+  // (previously control could fall off the end: -Wreturn-type).
+  switch (Method) {
+  case ParserMethod::ParseRootFlags:
+    return ParseRootFlags();
+  case ParserMethod::ReportError:
+    break;
+  }
+  return ReportError();
+}
+
+// Parser entry point function.
+//
+// Parses a comma separated list of root elements of the form
+// `Keyword(...)`. Whitespace is insignificant around keywords, parentheses
+// and separators, so e.g. " RootFlags (0) , CBV(b0) " is accepted.
+// Returns true on the first error, false once the whole buffer is consumed.
+bool Parser::Parse() {
+  StringLiteral Prefix = "";
+  // Skip leading whitespace so whitespace-only input is treated as empty.
+  Buffer = Buffer.drop_while(isspace);
+  while (!Buffer.empty()) {
+    // Every element after the first must be preceded by a comma.
+    if (!Buffer.consume_front(Prefix))
+      return ReportError();
+    Prefix = ",";
+
+    Buffer = Buffer.drop_while(isspace);
+
+    // Retrieve the root element identifier, ignoring any whitespace between
+    // it and the opening parenthesis.
+    auto [Name, Rest] = Buffer.split('(');
+    Token = Name.rtrim();
+    Buffer = Rest;
+
+    // Dispatch to the applicable root element parser
+    if (ParseRootElement())
+      return true;
+
+    // Then we can clean up the remaining ")"
+    if (!Buffer.consume_front(")"))
+      return ReportError();
+
+    // Skip whitespace before the next separator or the end of the input.
+    Buffer = Buffer.drop_while(isspace);
+  }
+
+  // All input has been correctly parsed
+  return false;
+}
+
+} // namespace root_signature
+} // namespace hlsl
+} // namespace llvm
diff --git a/clang/unittests/Sema/CMakeLists.txt b/clang/unittests/Sema/CMakeLists.txt
index 7ded562e8edfa5..f382f8b1235306 100644
--- a/clang/unittests/Sema/CMakeLists.txt
+++ b/clang/unittests/Sema/CMakeLists.txt
@@ -7,6 +7,7 @@ add_clang_unittest(SemaTests
ExternalSemaSourceTest.cpp
CodeCompleteTest.cpp
GslOwnerPointerInference.cpp
+ ParseHLSLRootSignatureTest.cpp
SemaLookupTest.cpp
SemaNoloadLookupTest.cpp
)
diff --git a/clang/unittests/Sema/ParseHLSLRootSignatureTest.cpp b/clang/unittests/Sema/ParseHLSLRootSignatureTest.cpp
new file mode 100644
index 00000000000000..0e7feb50871669
--- /dev/null
+++ b/clang/unittests/Sema/ParseHLSLRootSignatureTest.cpp
@@ -0,0 +1,58 @@
+//=== ParseHLSLRootSignatureTest.cpp - Parse Root Signature tests ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/Sema/ParseHLSLRootSignature.h"
+#include "gtest/gtest.h"
+
+using namespace llvm::hlsl::root_signature;
+
+namespace {
+
+// Each test parses a single RootFlags element and checks both the union
+// discriminator (Tag) and the decoded flags of the produced RootElement.
+
+TEST(ParseHLSLRootSignature, EmptyRootFlags) {
+  llvm::StringRef RootFlagString = " RootFlags()";
+  llvm::SmallVector<RootElement> RootElements;
+  Parser Parser(RootFlagString, &RootElements);
+  ASSERT_FALSE(Parser.Parse());
+  ASSERT_EQ(RootElements.size(), 1u);
+  ASSERT_EQ(RootElement::ElementType::RootFlags, RootElements[0].Tag);
+  ASSERT_EQ(RootFlags::None, RootElements[0].Flags);
+}
+
+TEST(ParseHLSLRootSignature, RootFlagsNone) {
+  llvm::StringRef RootFlagString = " RootFlags(0)";
+  llvm::SmallVector<RootElement> RootElements;
+  Parser Parser(RootFlagString, &RootElements);
+  ASSERT_FALSE(Parser.Parse());
+  ASSERT_EQ(RootElements.size(), 1u);
+  ASSERT_EQ(RootElement::ElementType::RootFlags, RootElements[0].Tag);
+  ASSERT_EQ(RootFlags::None, RootElements[0].Flags);
+}
+
+TEST(ParseHLSLRootSignature, ValidRootFlags) {
+  // Test that the flags are all captured and that they are case insensitive
+  llvm::StringRef RootFlagString = " RootFlags( "
+                                   " ALLOW_INPUT_ASSEMBLER_INPUT_LAYOUT"
+                                   "| deny_vertex_shader_root_access"
+                                   "| DENY_HULL_SHADER_ROOT_ACCESS"
+                                   "| deny_domain_shader_root_access"
+                                   "| DENY_GEOMETRY_SHADER_ROOT_ACCESS"
+                                   "| deny_pixel_shader_root_access"
+                                   "| ALLOW_STREAM_OUTPUT"
+                                   "| LOCAL_ROOT_SIGNATURE"
+                                   "| deny_amplification_shader_root_access"
+                                   "| DENY_MESH_SHADER_ROOT_ACCESS"
+                                   "| cbv_srv_uav_heap_directly_indexed"
+                                   "| SAMPLER_HEAP_DIRECTLY_INDEXED"
+                                   "| AllowLowTierReservedHwCbLimit )";
+
+  llvm::SmallVector<RootElement> RootElements;
+  Parser Parser(RootFlagString, &RootElements);
+  ASSERT_FALSE(Parser.Parse());
+  ASSERT_EQ(RootElements.size(), 1u);
+  ASSERT_EQ(RootElement::ElementType::RootFlags, RootElements[0].Tag);
+  ASSERT_EQ(RootFlags::ValidFlags, RootElements[0].Flags);
+}
+
+} // anonymous namespace
diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
new file mode 100644
index 00000000000000..a17ebffc7a6bf2
--- /dev/null
+++ b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
@@ -0,0 +1,88 @@
+//===- HLSLRootSignature.h - HLSL Root Signature helper objects -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file This file contains helper objects for working with HLSL Root
+/// Signatures.
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_FRONTEND_HLSL_HLSLROOTSIGNATURE_H
+#define LLVM_FRONTEND_HLSL_HLSLROOTSIGNATURE_H
+
+#include <stdint.h>
+
+#include "llvm/Support/Endian.h"
+
+namespace llvm {
+namespace hlsl {
+namespace root_signature {
+
+// Generates the bitwise operators (|, &, ~, |=, &=) for a scoped enum so its
+// enumerators can be combined as flags. Adapted from
+// DebugInfo/CodeView/CodeView.h.
+#define RS_DEFINE_ENUM_CLASS_FLAGS_OPERATORS(Class)                            \
+  inline Class operator|(Class LHS, Class RHS) {                               \
+    return static_cast<Class>(llvm::to_underlying(LHS) |                       \
+                              llvm::to_underlying(RHS));                       \
+  }                                                                            \
+  inline Class operator&(Class LHS, Class RHS) {                               \
+    return static_cast<Class>(llvm::to_underlying(LHS) &                       \
+                              llvm::to_underlying(RHS));                       \
+  }                                                                            \
+  inline Class operator~(Class Value) {                                        \
+    return static_cast<Class>(~llvm::to_underlying(Value));                    \
+  }                                                                            \
+  inline Class &operator|=(Class &LHS, Class RHS) { return LHS = LHS | RHS; }  \
+  inline Class &operator&=(Class &LHS, Class RHS) { return LHS = LHS & RHS; }
+
+// Various enumerations and flags
+
+// Root signature flags. Each enumerator is a single bit so values can be
+// combined with the operators generated by
+// RS_DEFINE_ENUM_CLASS_FLAGS_OPERATORS.
+enum class RootFlags : uint32_t {
+  None = 0,
+  AllowInputAssemblerInputLayout = 1u << 0,
+  DenyVertexShaderRootAccess = 1u << 1,
+  DenyHullShaderRootAccess = 1u << 2,
+  DenyDomainShaderRootAccess = 1u << 3,
+  DenyGeometryShaderRootAccess = 1u << 4,
+  DenyPixelShaderRootAccess = 1u << 5,
+  AllowStreamOutput = 1u << 6,
+  LocalRootSignature = 1u << 7,
+  DenyAmplificationShaderRootAccess = 1u << 8,
+  DenyMeshShaderRootAccess = 1u << 9,
+  CBVSRVUAVHeapDirectlyIndexed = 1u << 10,
+  SamplerHeapDirectlyIndexed = 1u << 11,
+  AllowLowTierReservedHwCbLimit = 1u << 31,
+  // Every flag above OR'ed together (bits 0-11 and bit 31).
+  ValidFlags = 0x80000fff
+};
+RS_DEFINE_ENUM_CLASS_FLAGS_OPERATORS(RootFlags)
+
+// Define the in-memory layout structures
+
+// One parsed root element, stored as a tagged union: Tag names the union
+// member that the constructor initialised.
+struct RootElement {
+  enum class ElementType { RootFlags };
+
+  ElementType Tag;
+  union {
+    RootFlags Flags;
+  };
+
+  RootElement(RootFlags Value) : Tag(ElementType::RootFlags), Flags(Value) {}
+};
+
+} // namespace root_signature
+} // namespace hlsl
+} // namespace llvm
+
+#endif // LLVM_FRONTEND_HLSL_HLSLROOTSIGNATURE_H
>From 8ed0186a9e3a9cf2af45eb102dff77fabcc4cb97 Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Mon, 6 Jan 2025 17:02:07 +0000
Subject: [PATCH 2/2] [HLSL][RootSignature] Implement parsing of
`RootParameter`s
- Implement the ParseRootParameter methods in ParseHLSLRootSignature
- Define the in-memory representation of the various RootParameters and
add them to the RootElement structure
- Add testing of valid inputs to ParseHLSLRootSignatureTest.cpp
---
.../clang/Sema/ParseHLSLRootSignature.h | 12 ++
clang/lib/Sema/ParseHLSLRootSignature.cpp | 176 ++++++++++++++++++
.../Sema/ParseHLSLRootSignatureTest.cpp | 98 ++++++++++
.../llvm/Frontend/HLSL/HLSLRootSignature.h | 45 +++++
4 files changed, 331 insertions(+)
diff --git a/clang/include/clang/Sema/ParseHLSLRootSignature.h b/clang/include/clang/Sema/ParseHLSLRootSignature.h
index 7d1799e22b515c..f06e02800e5f5f 100644
--- a/clang/include/clang/Sema/ParseHLSLRootSignature.h
+++ b/clang/include/clang/Sema/ParseHLSLRootSignature.h
@@ -13,6 +13,7 @@
#ifndef LLVM_CLANG_SEMA_PARSEHLSLROOTSIGNATURE_H
#define LLVM_CLANG_SEMA_PARSEHLSLROOTSIGNATURE_H
+#include "llvm/ADT/APInt.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
@@ -43,11 +44,22 @@ class Parser {
bool ParseRootElement();
bool ParseRootFlags();
+ // Parses CBV, SRV, UAV and RootConstants elements into a RootParameter.
+ bool ParseRootParameter();
+
+ // Helper methods
+ // Each helper returns true on failure and false on success.
+ // Parses '=' with optional surrounding whitespace.
+ bool ParseAssign();
+ // Parses ',' followed by optional whitespace.
+ bool ParseComma();
+ // Parses an optional ','; without one, the next non-whitespace character
+ // must be ')'.
+ bool ParseOptComma();
+ // Parses a register reference such as "b0", "t3", "u1" or "s2".
+ bool ParseRegister(Register &);
+ // Parses a run of decimal digits into Number.
+ bool ParseUnsignedInt(uint32_t &Number);
+
// Enum methods
template <typename EnumType>
bool ParseEnum(SmallVector<std::pair<StringLiteral, EnumType>> Mapping,
EnumType &Enum);
+ // Parses one RootDescriptorFlags keyword (value of the flags argument).
+ bool ParseRootDescriptorFlag(RootDescriptorFlags &Flag);
bool ParseRootFlag(RootFlags &Flag);
+ // Parses one SHADER_VISIBILITY_* keyword.
+ bool ParseVisibility(ShaderVisibility &Visibility);
// Internal state used when parsing
StringRef Buffer;
diff --git a/clang/lib/Sema/ParseHLSLRootSignature.cpp b/clang/lib/Sema/ParseHLSLRootSignature.cpp
index e4592ea1937178..72bc5b1be32625 100644
--- a/clang/lib/Sema/ParseHLSLRootSignature.cpp
+++ b/clang/lib/Sema/ParseHLSLRootSignature.cpp
@@ -38,6 +38,148 @@ bool Parser::ParseRootFlags() {
return false;
}
+// Parses the arguments of a root parameter element (CBV, SRV, UAV or
+// RootConstants, whose keyword is held in Token) and appends the resulting
+// RootParameter to Elements. Returns true on error.
+//
+// Arguments are only accepted in this fixed order:
+//   num32BitConstants = N  (RootConstants only, mandatory)
+//   register               (mandatory, e.g. b0)
+//   space = N              (optional)
+//   visibility = ...       (optional)
+//   flags = ...            (optional, not for RootConstants)
+// The closing ')' is left in Buffer for Parse() to consume.
+bool Parser::ParseRootParameter() {
+ RootParameter Parameter;
+ Parameter.Type = llvm::StringSwitch<RootType>(Token)
+ .Case("CBV", RootType::CBV)
+ .Case("SRV", RootType::SRV)
+ .Case("UAV", RootType::UAV)
+ .Case("RootConstants", RootType::Constants);
+ // No .Default() is needed: ParseRootElement only dispatches here for these
+ // four keywords, so Token always matches one of the cases above.
+
+ // Remove any whitespace
+ Buffer = Buffer.drop_while(isspace);
+
+ // Retrieve mandatory num32BitConstants arg for RootConstants
+ if (Parameter.Type == RootType::Constants) {
+ if (!Buffer.consume_front("num32BitConstants"))
+ return ReportError();
+
+ if (ParseAssign())
+ return ReportError();
+
+ if (ParseUnsignedInt(Parameter.Num32BitConstants))
+ return ReportError();
+
+ if (ParseOptComma())
+ return ReportError();
+ }
+
+ // Retrieve mandatory register
+ // (ParseRegister already goes through ReportError on failure.)
+ if (ParseRegister(Parameter.Register))
+ return true;
+
+ if (ParseOptComma())
+ return ReportError();
+
+ // Parse common optional space arg
+ if (Buffer.consume_front("space")) {
+ if (ParseAssign())
+ return ReportError();
+
+ if (ParseUnsignedInt(Parameter.Space))
+ return ReportError();
+
+ if (ParseOptComma())
+ return ReportError();
+ }
+
+ // Parse common optional visibility arg
+ if (Buffer.consume_front("visibility")) {
+ if (ParseAssign())
+ return ReportError();
+
+ if (ParseVisibility(Parameter.Visibility))
+ return ReportError();
+
+ if (ParseOptComma())
+ return ReportError();
+ }
+
+ // Retrieve optional flags arg for non-RootConstants
+ // NOTE(review): only a single descriptor flag keyword is parsed here;
+ // '|'-combined flags are not accepted.
+ if (Parameter.Type != RootType::Constants && Buffer.consume_front("flags")) {
+ if (ParseAssign())
+ return ReportError();
+
+ if (ParseRootDescriptorFlag(Parameter.Flags))
+ return ReportError();
+
+ // Remove trailing whitespace
+ Buffer = Buffer.drop_while(isspace);
+ }
+
+ // Create and push the root element on the parsed elements
+ Elements->push_back(RootElement(Parameter));
+ return false;
+}
+
+// Helper Parser methods
+
+// Parses '=' surrounded by any amount of whitespace, e.g. " = ".
+// Returns true if no '=' follows the (already skipped) leading whitespace.
+bool Parser::ParseAssign() {
+  Buffer = Buffer.drop_while(isspace);
+  const bool Missing = !Buffer.consume_front("=");
+  if (!Missing)
+    Buffer = Buffer.drop_while(isspace);
+  return Missing;
+}
+
+// Parses ',' followed by any amount of whitespace. Returns true, consuming
+// nothing, when the buffer does not start with a comma.
+bool Parser::ParseComma() {
+  if (!Buffer.consume_front(","))
+    return true;
+  Buffer = Buffer.drop_while(isspace);
+  return false;
+}
+
+// Parses an optional ", " separator between arguments.
+//  - With a comma: it is consumed, another argument is expected, and false
+//    is returned.
+//  - Without one: the argument list must end here, so true (error) is
+//    returned unless the next non-whitespace character is ')'.
+bool Parser::ParseOptComma() {
+  const bool HasComma = !ParseComma();
+  if (HasComma)
+    return false;
+
+  Buffer = Buffer.drop_while(isspace);
+  return !Buffer.starts_with(')');
+}
+
+// Parses a register reference: one type letter ('b', 't', 'u' or 's')
+// immediately followed by an unsigned register number, e.g. "b42".
+bool Parser::ParseRegister(Register &Register) {
+  if (Buffer.empty())
+    return ReportError();
+
+  // Split off the type letter, keeping it in Token.
+  Token = Buffer.take_front();
+  Buffer = Buffer.drop_front();
+
+  switch (Token.front()) {
+  case 'b':
+    Register.ViewType = RegisterType::BReg;
+    break;
+  case 't':
+    Register.ViewType = RegisterType::TReg;
+    break;
+  case 'u':
+    Register.ViewType = RegisterType::UReg;
+    break;
+  case 's':
+    Register.ViewType = RegisterType::SReg;
+    break;
+  default:
+    return ReportError();
+  }
+
+  if (ParseUnsignedInt(Register.Number))
+    return ReportError();
+  return false;
+}
+
+// Parses "[0-9]+" as an unsigned 32-bit integer into Number.
+// Returns true, leaving Buffer and Number untouched, when there are no digits
+// or when the value does not fit in 32 bits. (Parsing into an APInt, as done
+// previously, silently grew the APInt and truncated large values when they
+// were narrowed to uint32_t.)
+bool Parser::ParseUnsignedInt(uint32_t &Number) {
+  StringRef NumString = Buffer.take_while(isdigit);
+  // StringRef::getAsInteger rejects empty input and values that overflow the
+  // destination type, and only writes Value on success.
+  uint32_t Value;
+  if (NumString.getAsInteger(/*Radix=*/10, Value))
+    return true;
+  Number = Value;
+  Buffer = Buffer.drop_front(NumString.size());
+  return false;
+}
+
template <typename EnumType>
bool Parser::ParseEnum(SmallVector<std::pair<StringLiteral, EnumType>> Mapping,
EnumType &Enum) {
@@ -57,6 +199,18 @@ bool Parser::ParseEnum(SmallVector<std::pair<StringLiteral, EnumType>> Mapping,
return false;
}
+// Parses one RootDescriptorFlags keyword (matched case-insensitively) into
+// Flag. Returns true if the next token is not a known descriptor flag.
+bool Parser::ParseRootDescriptorFlag(RootDescriptorFlags &Flag) {
+  return ParseEnum<RootDescriptorFlags>(
+      {
+          {"0", RootDescriptorFlags::None},
+          {"DATA_VOLATILE", RootDescriptorFlags::DataVolatile},
+          {"DATA_STATIC_WHILE_SET_AT_EXECUTE",
+           RootDescriptorFlags::DataStaticWhileSetAtExecute},
+          {"DATA_STATIC", RootDescriptorFlags::DataStatic},
+      },
+      Flag);
+}
+
bool Parser::ParseRootFlag(RootFlags &Flag) {
SmallVector<std::pair<StringLiteral, RootFlags>> Mapping = {
{"0", RootFlags::None},
@@ -83,16 +237,36 @@ bool Parser::ParseRootFlag(RootFlags &Flag) {
return ParseEnum<RootFlags>(Mapping, Flag);
}
+// Parses one SHADER_VISIBILITY_* keyword (matched case-insensitively) into
+// Visibility. Returns true if the next token is not a known visibility.
+bool Parser::ParseVisibility(ShaderVisibility &Visibility) {
+  return ParseEnum<ShaderVisibility>(
+      {
+          {"SHADER_VISIBILITY_ALL", ShaderVisibility::All},
+          {"SHADER_VISIBILITY_VERTEX", ShaderVisibility::Vertex},
+          {"SHADER_VISIBILITY_HULL", ShaderVisibility::Hull},
+          {"SHADER_VISIBILITY_DOMAIN", ShaderVisibility::Domain},
+          {"SHADER_VISIBILITY_GEOMETRY", ShaderVisibility::Geometry},
+          {"SHADER_VISIBILITY_PIXEL", ShaderVisibility::Pixel},
+          {"SHADER_VISIBILITY_AMPLIFICATION", ShaderVisibility::Amplification},
+          {"SHADER_VISIBILITY_MESH", ShaderVisibility::Mesh},
+      },
+      Visibility);
+}
+
bool Parser::ParseRootElement() {
// Define different ParserMethods to use StringSwitch for dispatch
enum class ParserMethod {
ReportError,
ParseRootFlags,
+ ParseRootParameter,
};
// Retreive which method should be used
auto Method = llvm::StringSwitch<ParserMethod>(Token)
.Case("RootFlags", ParserMethod::ParseRootFlags)
+ // All four root parameter kinds share one parser, which re-derives the
+ // RootType from Token.
+ .Case("RootConstants", ParserMethod::ParseRootParameter)
+ .Case("CBV", ParserMethod::ParseRootParameter)
+ .Case("SRV", ParserMethod::ParseRootParameter)
+ .Case("UAV", ParserMethod::ParseRootParameter)
.Default(ParserMethod::ReportError);
// Dispatch on the correct method
@@ -101,6 +275,8 @@ bool Parser::ParseRootElement() {
return ReportError();
case ParserMethod::ParseRootFlags:
return ParseRootFlags();
+ case ParserMethod::ParseRootParameter:
+ return ParseRootParameter();
+ // NOTE(review): every enumerator returns above, but control can still
+ // reach the end of this non-void function as far as the compiler is
+ // concerned (-Wreturn-type); a trailing `return ReportError();` or
+ // llvm_unreachable after the switch would close that path.
}
}
diff --git a/clang/unittests/Sema/ParseHLSLRootSignatureTest.cpp b/clang/unittests/Sema/ParseHLSLRootSignatureTest.cpp
index 0e7feb50871669..eefe20293732ad 100644
--- a/clang/unittests/Sema/ParseHLSLRootSignatureTest.cpp
+++ b/clang/unittests/Sema/ParseHLSLRootSignatureTest.cpp
@@ -55,4 +55,102 @@ TEST(ParseHLSLRootSignature, ValidRootFlags) {
ASSERT_EQ(RootFlags::ValidFlags, RootElements[0].Flags);
}
+// Each test parses a single root parameter element, checks that the union
+// discriminator selects RootParameter, and then checks every field.
+
+TEST(ParseHLSLRootSignature, MandatoryRootConstant) {
+  llvm::StringRef ParamString = "RootConstants(num32BitConstants = 4, b42)";
+  llvm::SmallVector<RootElement> RootElements;
+  Parser Parser(ParamString, &RootElements);
+  ASSERT_FALSE(Parser.Parse());
+  ASSERT_EQ(RootElements.size(), 1u);
+  ASSERT_EQ(RootElement::ElementType::RootParameter, RootElements[0].Tag);
+
+  const RootParameter &Parameter = RootElements[0].Parameter;
+  ASSERT_EQ(RootType::Constants, Parameter.Type);
+  ASSERT_EQ(RegisterType::BReg, Parameter.Register.ViewType);
+  ASSERT_EQ(42u, Parameter.Register.Number);
+  ASSERT_EQ(4u, Parameter.Num32BitConstants);
+  ASSERT_EQ(0u, Parameter.Space);
+  ASSERT_EQ(ShaderVisibility::All, Parameter.Visibility);
+}
+
+TEST(ParseHLSLRootSignature, OptionalRootConstant) {
+  llvm::StringRef ParamString =
+      "RootConstants(num32BitConstants = 4, b42, space = 4, visibility = "
+      "SHADER_VISIBILITY_DOMAIN)";
+  llvm::SmallVector<RootElement> RootElements;
+  Parser Parser(ParamString, &RootElements);
+  ASSERT_FALSE(Parser.Parse());
+  ASSERT_EQ(RootElements.size(), 1u);
+  ASSERT_EQ(RootElement::ElementType::RootParameter, RootElements[0].Tag);
+
+  const RootParameter &Parameter = RootElements[0].Parameter;
+  ASSERT_EQ(RootType::Constants, Parameter.Type);
+  ASSERT_EQ(RegisterType::BReg, Parameter.Register.ViewType);
+  ASSERT_EQ(42u, Parameter.Register.Number);
+  ASSERT_EQ(4u, Parameter.Num32BitConstants);
+  ASSERT_EQ(4u, Parameter.Space);
+  ASSERT_EQ(ShaderVisibility::Domain, Parameter.Visibility);
+}
+
+TEST(ParseHLSLRootSignature, DefaultRootCBV) {
+  llvm::StringRef ViewsString = "CBV(b0)";
+  llvm::SmallVector<RootElement> RootElements;
+  Parser Parser(ViewsString, &RootElements);
+  ASSERT_FALSE(Parser.Parse());
+  ASSERT_EQ(RootElements.size(), 1u);
+  ASSERT_EQ(RootElement::ElementType::RootParameter, RootElements[0].Tag);
+
+  const RootParameter &Parameter = RootElements[0].Parameter;
+  ASSERT_EQ(RootType::CBV, Parameter.Type);
+  ASSERT_EQ(RegisterType::BReg, Parameter.Register.ViewType);
+  ASSERT_EQ(0u, Parameter.Register.Number);
+  ASSERT_EQ(RootDescriptorFlags::None, Parameter.Flags);
+  ASSERT_EQ(0u, Parameter.Space);
+  ASSERT_EQ(ShaderVisibility::All, Parameter.Visibility);
+}
+
+TEST(ParseHLSLRootSignature, SampleRootCBV) {
+  llvm::StringRef ViewsString = "CBV(b982374, space = 1, flags = DATA_STATIC)";
+  llvm::SmallVector<RootElement> RootElements;
+  Parser Parser(ViewsString, &RootElements);
+  ASSERT_FALSE(Parser.Parse());
+  ASSERT_EQ(RootElements.size(), 1u);
+  ASSERT_EQ(RootElement::ElementType::RootParameter, RootElements[0].Tag);
+
+  const RootParameter &Parameter = RootElements[0].Parameter;
+  ASSERT_EQ(RootType::CBV, Parameter.Type);
+  ASSERT_EQ(RegisterType::BReg, Parameter.Register.ViewType);
+  ASSERT_EQ(982374u, Parameter.Register.Number);
+  ASSERT_EQ(RootDescriptorFlags::DataStatic, Parameter.Flags);
+  ASSERT_EQ(1u, Parameter.Space);
+  ASSERT_EQ(ShaderVisibility::All, Parameter.Visibility);
+}
+
+TEST(ParseHLSLRootSignature, SampleRootSRV) {
+  llvm::StringRef ViewsString = "SRV(t3, visibility = SHADER_VISIBILITY_MESH, "
+                                "flags = Data_Static_While_Set_At_Execute)";
+  llvm::SmallVector<RootElement> RootElements;
+  Parser Parser(ViewsString, &RootElements);
+  ASSERT_FALSE(Parser.Parse());
+  ASSERT_EQ(RootElements.size(), 1u);
+  ASSERT_EQ(RootElement::ElementType::RootParameter, RootElements[0].Tag);
+
+  const RootParameter &Parameter = RootElements[0].Parameter;
+  ASSERT_EQ(RootType::SRV, Parameter.Type);
+  ASSERT_EQ(RegisterType::TReg, Parameter.Register.ViewType);
+  ASSERT_EQ(3u, Parameter.Register.Number);
+  ASSERT_EQ(RootDescriptorFlags::DataStaticWhileSetAtExecute, Parameter.Flags);
+  ASSERT_EQ(0u, Parameter.Space);
+  ASSERT_EQ(ShaderVisibility::Mesh, Parameter.Visibility);
+}
+
+TEST(ParseHLSLRootSignature, SampleRootUAV) {
+  llvm::StringRef ViewsString = "UAV(u0, flags = DATA_VOLATILE)";
+  llvm::SmallVector<RootElement> RootElements;
+  Parser Parser(ViewsString, &RootElements);
+  ASSERT_FALSE(Parser.Parse());
+  ASSERT_EQ(RootElements.size(), 1u);
+  ASSERT_EQ(RootElement::ElementType::RootParameter, RootElements[0].Tag);
+
+  const RootParameter &Parameter = RootElements[0].Parameter;
+  ASSERT_EQ(RootType::UAV, Parameter.Type);
+  ASSERT_EQ(RegisterType::UReg, Parameter.Register.ViewType);
+  ASSERT_EQ(0u, Parameter.Register.Number);
+  ASSERT_EQ(RootDescriptorFlags::DataVolatile, Parameter.Flags);
+  ASSERT_EQ(0u, Parameter.Space);
+  ASSERT_EQ(ShaderVisibility::All, Parameter.Visibility);
+}
} // anonymous namespace
diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
index a17ebffc7a6bf2..61fd47bbb48ab1 100644
--- a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
+++ b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
@@ -65,20 +65,65 @@ enum class RootFlags : uint32_t {
};
RS_DEFINE_ENUM_CLASS_FLAGS_OPERATORS(RootFlags)
+// Flags accepted by the `flags` argument of CBV/SRV/UAV root parameters.
+// Uses a fixed-width 32-bit underlying type, consistent with RootFlags.
+enum class RootDescriptorFlags : uint32_t {
+  None = 0,
+  DataVolatile = 0x2,
+  DataStaticWhileSetAtExecute = 0x4,
+  DataStatic = 0x8,
+  // Every flag above OR'ed together.
+  ValidFlags = 0xe
+};
+RS_DEFINE_ENUM_CLASS_FLAGS_OPERATORS(RootDescriptorFlags)
+
+// Shader stage(s) that can access a root parameter; each enumerator is
+// parsed from the SHADER_VISIBILITY_* keyword named in its comment.
+enum class ShaderVisibility {
+  All = 0,           // SHADER_VISIBILITY_ALL
+  Vertex = 1,        // SHADER_VISIBILITY_VERTEX
+  Hull = 2,          // SHADER_VISIBILITY_HULL
+  Domain = 3,        // SHADER_VISIBILITY_DOMAIN
+  Geometry = 4,      // SHADER_VISIBILITY_GEOMETRY
+  Pixel = 5,         // SHADER_VISIBILITY_PIXEL
+  Amplification = 6, // SHADER_VISIBILITY_AMPLIFICATION
+  Mesh = 7,          // SHADER_VISIBILITY_MESH
+};
+
// Define the in-memory layout structures
+// Register classes a root parameter can be bound to, written in HLSL as the
+// prefixes 'b', 't', 'u' and 's' respectively.
+enum class RegisterType {
+  BReg,
+  TReg,
+  UReg,
+  SReg,
+};
+
+// A register reference such as "t3": its class plus its index.
+struct Register {
+  RegisterType ViewType;
+  uint32_t Number;
+};
+
+// Models RootConstants | RootCBV | RootSRV | RootUAV collecting like
+// parameters
+enum class RootType { CBV, SRV, UAV, Constants };
+struct RootParameter {
+  RootType Type;
+  // The type name is qualified on purpose: an unqualified
+  // `Register Register;` changes the meaning of `Register` inside this class,
+  // which GCC rejects ("declaration ... changes meaning of 'Register'").
+  root_signature::Register Register;
+  // Only one member is meaningful, selected by Type: Num32BitConstants for
+  // RootType::Constants, Flags for CBV/SRV/UAV.
+  union {
+    uint32_t Num32BitConstants;
+    RootDescriptorFlags Flags = RootDescriptorFlags::None;
+  };
+  uint32_t Space = 0;
+  ShaderVisibility Visibility = ShaderVisibility::All;
+};
+
struct RootElement {
enum class ElementType {
RootFlags,
+ RootParameter,
};
+ // Names the union member below that the constructor initialised; only that
+ // member should be read.
ElementType Tag;
union {
RootFlags Flags;
+ RootParameter Parameter;
};
// Constructors
RootElement(RootFlags Flags) : Tag(ElementType::RootFlags), Flags(Flags) {}
+ // Wraps a parsed CBV/SRV/UAV/RootConstants parameter.
+ RootElement(RootParameter Parameter)
+ : Tag(ElementType::RootParameter), Parameter(Parameter) {}
};
} // namespace root_signature
More information about the llvm-commits
mailing list