[llvm-branch-commits] [clang] [HLSL] Define the HLSLRootSignature Attr (PR #123985)

Finn Plummer via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Jan 22 10:51:17 PST 2025


https://github.com/inbelic created https://github.com/llvm/llvm-project/pull/123985

- Defines HLSLRootSignature Attr in `Attr.td`
- Define and implement handleHLSLRootSignature in `SemaHLSL`
- Adds sample test case to show AST Node is generated in `RootSignatures-AST.hlsl`

This commit will "hook up" the separately defined RootSignature parser and invoke it to create the RootElements, then store them on the ASTContext and finally store the reference to the Elements in RootSignatureAttr.

This PR essentially resolves #119011; however, we will keep it open until we can resolve the FIXME, once floating-point is supported in the parser/lexer. This will be done when we add support for StaticSamplers.

>From 42932906dc811f2f6953d0902143419daa030be7 Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Wed, 22 Jan 2025 17:53:59 +0000
Subject: [PATCH] [HLSL] Define the HLSLRootSignature Attr

- Defines HLSLRootSignature Attr in `Attr.td`
- Define and implement handleHLSLRootSignature in `SemaHLSL`
- Adds sample test case to show AST Node is generated in
`RootSignatures-AST.hlsl`

This commit will "hook up" the separately defined RootSignature parser
and invoke it to create the RootElements, then store them on the
ASTContext and finally store the reference to the Elements in
RootSignatureAttr
---
 clang/include/clang/AST/Attr.h              |  1 +
 clang/include/clang/Basic/Attr.td           | 20 ++++++++++++++
 clang/include/clang/Basic/AttrDocs.td       |  4 +++
 clang/include/clang/Sema/SemaHLSL.h         |  1 +
 clang/lib/Sema/SemaDeclAttr.cpp             |  3 +++
 clang/lib/Sema/SemaHLSL.cpp                 | 30 +++++++++++++++++++++
 clang/test/AST/HLSL/RootSignatures-AST.hlsl | 28 +++++++++++++++++++
 7 files changed, 87 insertions(+)
 create mode 100644 clang/test/AST/HLSL/RootSignatures-AST.hlsl

diff --git a/clang/include/clang/AST/Attr.h b/clang/include/clang/AST/Attr.h
index 3365ebe4d9012b..d45b8891cf1a72 100644
--- a/clang/include/clang/AST/Attr.h
+++ b/clang/include/clang/AST/Attr.h
@@ -26,6 +26,7 @@
 #include "clang/Basic/SourceLocation.h"
 #include "clang/Support/Compiler.h"
 #include "llvm/Frontend/HLSL/HLSLResource.h"
+#include "llvm/Frontend/HLSL/HLSLRootSignature.h"
 #include "llvm/Support/CodeGen.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/VersionTuple.h"
diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index 52ad72eb608c31..36ae98730db031 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -4643,6 +4643,26 @@ def Error : InheritableAttr {
   let Documentation = [ErrorAttrDocs];
 }
 
+/// HLSL root signature attribute: the signature string plus a view of its elements.
+def HLSLRootSignature : Attr {
+  /// Written in source as [RootSignature(Signature)]; applies to functions only.
+  let Spellings = [Microsoft<"RootSignature">];
+  let Args = [StringArgument<"Signature">];
+  let Subjects = SubjectList<[Function],
+                             ErrorDiag, "'function'">;
+  let LangOpts = [HLSL];
+  let Documentation = [HLSLRootSignatureDocs];
+  let AdditionalMembers = [{
+private:
+  ArrayRef<llvm::hlsl::root_signature::RootElement> RootElements;
+public:
+  void setElements(ArrayRef<llvm::hlsl::root_signature::RootElement> Elements) {
+    RootElements = Elements;
+  }
+  auto getElements() const { return RootElements; }
+}];
+}
+
 def HLSLNumThreads: InheritableAttr {
   let Spellings = [Microsoft<"numthreads">];
   let Args = [IntArgument<"X">, IntArgument<"Y">, IntArgument<"Z">];
diff --git a/clang/include/clang/Basic/AttrDocs.td b/clang/include/clang/Basic/AttrDocs.td
index fdad4c9a3ea191..bb0934a11f9f3f 100644
--- a/clang/include/clang/Basic/AttrDocs.td
+++ b/clang/include/clang/Basic/AttrDocs.td
@@ -7783,6 +7783,10 @@ and https://microsoft.github.io/hlsl-specs/proposals/0013-wave-size-range.html
   }];
 }
 
+def HLSLRootSignatureDocs : Documentation {
+  let Category = DocCatUndocumented; // No Content is provided for this attribute.
+}
+
 def NumThreadsDocs : Documentation {
   let Category = DocCatFunction;
   let Content = [{
diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h
index f4cd11f423a84a..df4a5c8d88ba9e 100644
--- a/clang/include/clang/Sema/SemaHLSL.h
+++ b/clang/include/clang/Sema/SemaHLSL.h
@@ -116,6 +116,7 @@ class SemaHLSL : public SemaBase {
                                        bool IsCompAssign);
   void emitLogicalOperatorFixIt(Expr *LHS, Expr *RHS, BinaryOperatorKind Opc);
 
+  void handleRootSignatureAttr(Decl *D, const ParsedAttr &AL);
   void handleNumThreadsAttr(Decl *D, const ParsedAttr &AL);
   void handleWaveSizeAttr(Decl *D, const ParsedAttr &AL);
   void handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL);
diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
index bb4d33560b93b8..c594d6e54ddbcd 100644
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -7149,6 +7149,9 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL,
     break;
 
   // HLSL attributes:
+  case ParsedAttr::AT_HLSLRootSignature:
+    S.HLSL().handleRootSignatureAttr(D, AL);
+    break;
   case ParsedAttr::AT_HLSLNumThreads:
     S.HLSL().handleNumThreadsAttr(D, AL);
     break;
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 600c800029fd05..5dd1e872a2b6f7 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -24,6 +24,7 @@
 #include "clang/Basic/LLVM.h"
 #include "clang/Basic/SourceLocation.h"
 #include "clang/Basic/TargetInfo.h"
+#include "clang/Parse/ParseHLSLRootSignature.h"
 #include "clang/Sema/Initialization.h"
 #include "clang/Sema/ParsedAttr.h"
 #include "clang/Sema/Sema.h"
@@ -647,6 +648,35 @@ void SemaHLSL::emitLogicalOperatorFixIt(Expr *LHS, Expr *RHS,
       << NewFnName << FixItHint::CreateReplacement(FullRange, OS.str());
 }
 
+/// Handles `[RootSignature("...")]`: lexes and parses the signature string and
+/// attaches an HLSLRootSignatureAttr to D. Returns early if any step bails out.
+void SemaHLSL::handleRootSignatureAttr(Decl *D, const ParsedAttr &AL) {
+  using namespace llvm::hlsl::root_signature;
+  if (AL.getNumArgs() != 1) return;
+  StringRef Signature;
+  if (!SemaRef.checkStringLiteralArgumentAttr(AL, 0, Signature))
+    return;
+
+  SourceLocation Loc = AL.getArgAsExpr(0)->getExprLoc();
+  // FIXME: pass down below to lexer when fp is supported
+  // llvm::RoundingMode RM = SemaRef.CurFPFeatures.getRoundingMode();
+  SmallVector<RootSignatureToken> Tokens;
+  RootSignatureLexer Lexer(Signature, Loc, SemaRef.getPreprocessor());
+  if (Lexer.Lex(Tokens)) return;
+
+  SmallVector<RootElement> Elements;
+  RootSignatureParser Parser(Elements, Tokens);
+  if (Parser.Parse()) return;
+
+  // Elements is a local: copy it into ASTContext-owned storage so the ArrayRef
+  // kept by the attribute does not dangle once this function returns.
+  auto Owned = ArrayRef<RootElement>(Elements).copy(getASTContext());
+  auto *Result =
+    ::new (getASTContext()) HLSLRootSignatureAttr(getASTContext(), AL, Signature);
+  Result->setElements(Owned);
+  D->addAttr(Result);
+}
+
 void SemaHLSL::handleNumThreadsAttr(Decl *D, const ParsedAttr &AL) {
   llvm::VersionTuple SMVersion =
       getASTContext().getTargetInfo().getTriple().getOSVersion();
diff --git a/clang/test/AST/HLSL/RootSignatures-AST.hlsl b/clang/test/AST/HLSL/RootSignatures-AST.hlsl
new file mode 100644
index 00000000000000..dbc06b61cffebe
--- /dev/null
+++ b/clang/test/AST/HLSL/RootSignatures-AST.hlsl
@@ -0,0 +1,28 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-library -ast-dump \
+// RUN:  -disable-llvm-passes -o - %s | FileCheck %s
+
+// This test ensures that the sample root signature is parsed without error and
+// the Attr AST Node is created successfully. If an invalid root signature was
+// passed in then we would exit out of Sema before the Attr is created.
+
+#define SampleRS \
+  "DescriptorTable( " \
+  "  CBV(b1), " \
+  "  SRV(t1, numDescriptors = 8, " \
+  "          flags = DESCRIPTORS_VOLATILE), " \
+  "  UAV(u1, numDescriptors = 0, " \
+  "          flags = DESCRIPTORS_VOLATILE) " \
+  "), " \
+  "DescriptorTable(Sampler(s0, numDescriptors = 4, space = 1))"
+
+// CHECK:      HLSLRootSignatureAttr 0x{{[0-9A-Fa-f]+}} <line:{{[0-9]+}}:{{[0-9]+}}, col:{{[0-9]+}}>
+// CHECK-SAME: "DescriptorTable(
+// CHECK-SAME:   CBV(b1),
+// CHECK-SAME:   SRV(t1, numDescriptors = 8,
+// CHECK-SAME:           flags = DESCRIPTORS_VOLATILE),
+// CHECK-SAME:   UAV(u1, numDescriptors = 0,
+// CHECK-SAME:           flags = DESCRIPTORS_VOLATILE)
+// CHECK-SAME: ),
+// CHECK-SAME: DescriptorTable(Sampler(s0, numDescriptors = 4, space = 1))"
+[RootSignature(SampleRS)]
+void main() {}



More information about the llvm-branch-commits mailing list