[clang] [HLSL] Add HLSLRootSignatureAttr. (PR #83630)

Xiang Li via cfe-commits cfe-commits at lists.llvm.org
Fri Mar 1 15:03:19 PST 2024


https://github.com/python3kgae updated https://github.com/llvm/llvm-project/pull/83630

>From 994a475a75ffe9c7b7db40f4d374b4a4eb63068e Mon Sep 17 00:00:00 2001
From: Xiang Li <python3kgae at outlook.com>
Date: Fri, 1 Mar 2024 17:52:00 -0500
Subject: [PATCH 1/2] [HLSL] Add HLSLRootSignatureAttr.

First PR for RootSignature support.
HLSLRootSignatureAttr is added.
It saves the original input string (for rewriting) and a fake variable that will hold the parsed result.

Parsing will be implemented in a follow-up PR.

For #55116
---
 clang/include/clang/Basic/Attr.td             |  8 +++
 clang/include/clang/Sema/Sema.h               |  3 ++
 clang/lib/Sema/SemaDecl.cpp                   |  2 +
 clang/lib/Sema/SemaDeclAttr.cpp               | 49 +++++++++++++++++++
 clang/test/AST/HLSL/root_sigature.hlsl        | 14 ++++++
 .../test/SemaHLSL/ilegal_root_sigatures.hlsl  |  7 +++
 6 files changed, 83 insertions(+)
 create mode 100644 clang/test/AST/HLSL/root_sigature.hlsl
 create mode 100644 clang/test/SemaHLSL/ilegal_root_sigatures.hlsl

diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index fa191c7378dba4..01b0829acbe0ac 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -4403,6 +4403,14 @@ def HLSLParamModifier : TypeAttr {
   let Args = [DefaultBoolArgument<"MergedSpelling", /*default*/0, /*fake*/1>];
 }
 
+def HLSLRootSignature : InheritableAttr { // [RootSignature("...")] on HLSL entry functions.
+  let Spellings = [Microsoft<"RootSignature">];
+  let Subjects = SubjectList<[HLSLEntry]>;
+  let LangOpts = [HLSL];
+  let Documentation = [HLSLParamQualifierDocs]; // FIXME(review): reuses the param-qualifier docs; this attribute needs its own doc entry.
+  let Args = [StringArgument<"InputString"> , DeclArgument<Var, "RootSignatureObject", 0, /*fake*/ 1>]; // InputString: verbatim source text; RootSignatureObject: fake VarDecl for the parsed result.
+}
+
 def RandomizeLayout : InheritableAttr {
   let Spellings = [GCC<"randomize_layout">];
   let Subjects = SubjectList<[Record]>;
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index ef4b93fac95ce5..55b49f6ea67f0c 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -3862,6 +3862,9 @@ class Sema final {
   HLSLNumThreadsAttr *mergeHLSLNumThreadsAttr(Decl *D,
                                               const AttributeCommonInfo &AL,
                                               int X, int Y, int Z);
+  HLSLRootSignatureAttr *
+  mergeHLSLRootSignatureAttr(Decl *D, const AttributeCommonInfo &AL,
+                             StringRef OrigStr);
   HLSLShaderAttr *mergeHLSLShaderAttr(Decl *D, const AttributeCommonInfo &AL,
                                       HLSLShaderAttr::ShaderType ShaderType);
   HLSLParamModifierAttr *
diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp
index 9fdd8eb236d1ee..b5e8805c5fad42 100644
--- a/clang/lib/Sema/SemaDecl.cpp
+++ b/clang/lib/Sema/SemaDecl.cpp
@@ -2958,6 +2958,8 @@ static bool mergeDeclAttribute(Sema &S, NamedDecl *D,
   else if (const auto *NT = dyn_cast<HLSLNumThreadsAttr>(Attr))
     NewAttr =
         S.mergeHLSLNumThreadsAttr(D, *NT, NT->getX(), NT->getY(), NT->getZ());
+  else if (const auto *RS = dyn_cast<HLSLRootSignatureAttr>(Attr))
+    NewAttr = S.mergeHLSLRootSignatureAttr(D, *RS, RS->getInputString());
   else if (const auto *SA = dyn_cast<HLSLShaderAttr>(Attr))
     NewAttr = S.mergeHLSLShaderAttr(D, *SA, SA->getType());
   else if (isa<SuppressAttr>(Attr))
diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
index c1c28a73fd79ce..775e15d29e2144 100644
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -7154,6 +7154,52 @@ static void handleUuidAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
     D->addAttr(UA);
 }
 
+static void handleHLSLRootSignatureAttr(Sema &S, Decl *D,
+                                        const ParsedAttr &AL) {
+  // Argument 0 must be a string literal; it is stored verbatim as InputString.
+  StringRef OrigStrRef;
+  if (!S.checkStringLiteralArgumentAttr(AL, 0, OrigStrRef))
+    return;
+  HLSLRootSignatureAttr *RSA = S.mergeHLSLRootSignatureAttr(D, AL, OrigStrRef);
+  if (RSA)
+    D->addAttr(RSA);
+}
+
+HLSLRootSignatureAttr *
+Sema::mergeHLSLRootSignatureAttr(Decl *D, const AttributeCommonInfo &AL,
+                                 StringRef OrigStr) {
+  if (HLSLRootSignatureAttr *RS = D->getAttr<HLSLRootSignatureAttr>()) {
+    if (RS->getInputString() != OrigStr) {
+      Diag(RS->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
+      Diag(AL.getLoc(), diag::note_conflicting_attribute);
+    }
+    return nullptr;
+  }
+
+  // TODO: parse the OrigStr, report error if it's not valid.
+  FunctionDecl *FD = D->getAsFunction();
+  if (!FD)
+    return nullptr;
+  DeclContext *DC = FD->getParent();
+
+  // Create a record decl for the root signature.
+  IdentifierInfo *II = &Context.Idents.get(FD->getName().str() + ".RS");
+  RecordDecl *RD =
+      RecordDecl::Create(Context, TagDecl::TagKind::Struct, DC,
+                         SourceLocation(), SourceLocation(), II);
+  // TODO: Add fields to the record decl.
+
+  // Create a type for the root signature.
+  QualType T = Context.getRecordType(RD);
+  // Create a variable decl for the root signature.
+  VarDecl *VD = VarDecl::Create(Context, DC, SourceLocation(),
+                                SourceLocation(), II, T, nullptr, SC_None);
+
+  // TODO: Add initializers to the variable decl.
+
+  return ::new (Context) HLSLRootSignatureAttr(Context, AL, OrigStr, VD);
+}
+
 static void handleHLSLNumThreadsAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
   llvm::VersionTuple SMVersion =
       S.Context.getTargetInfo().getTriple().getOSVersion();
@@ -9645,6 +9691,9 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL,
   case ParsedAttr::AT_HLSLNumThreads:
     handleHLSLNumThreadsAttr(S, D, AL);
     break;
+  case ParsedAttr::AT_HLSLRootSignature:
+    handleHLSLRootSignatureAttr(S, D, AL);
+    break;
   case ParsedAttr::AT_HLSLSV_GroupIndex:
     handleSimpleAttribute<HLSLSV_GroupIndexAttr>(S, D, AL);
     break;
diff --git a/clang/test/AST/HLSL/root_sigature.hlsl b/clang/test/AST/HLSL/root_sigature.hlsl
new file mode 100644
index 00000000000000..79971a3cb8c0db
--- /dev/null
+++ b/clang/test/AST/HLSL/root_sigature.hlsl
@@ -0,0 +1,14 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -x hlsl -ast-dump -o - %s | FileCheck %s
+
+// Check that HLSLRootSignatureAttr keeps the (empty) input string and refers to Var 'main.RS'.
+
+// CHECK: FunctionDecl 0x{{.*}} main 'void ()'
+// CHECK-NEXT:   |-CompoundStmt
+// CHECK-NEXT:   |-HLSLShaderAttr 0x{{.*}} Compute
+// CHECK-NEXT:   |-HLSLRootSignatureAttr 0x{{.*}} "" Var 0x{{.*}} 'main.RS' 'main.RS'
+// CHECK-NEXT:   `-HLSLNumThreadsAttr 0x{{.*}} 1 1 1
+
+[shader("compute")]
+[RootSignature("")]
+[numthreads(1,1,1)]
+void main() {}
diff --git a/clang/test/SemaHLSL/ilegal_root_sigatures.hlsl b/clang/test/SemaHLSL/ilegal_root_sigatures.hlsl
new file mode 100644
index 00000000000000..fbe119e20d485f
--- /dev/null
+++ b/clang/test/SemaHLSL/ilegal_root_sigatures.hlsl
@@ -0,0 +1,7 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -verify %s
+
+// expected-error@+1 {{expected string literal as argument of 'RootSignature' attribute}}
+[RootSignature(1)]
+[shader("compute")]
+[numthreads(1,1,1)]
+void main() {}

>From 816129cb619c519c1cf1d7fdb3117755f7592128 Mon Sep 17 00:00:00 2001
From: Xiang Li <python3kgae at outlook.com>
Date: Fri, 1 Mar 2024 18:02:57 -0500
Subject: [PATCH 2/2] Fix format.

---
 clang/lib/Sema/SemaDeclAttr.cpp | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
index 775e15d29e2144..5a4be144bb0682 100644
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -7184,16 +7184,15 @@ Sema::mergeHLSLRootSignatureAttr(Decl *D, const AttributeCommonInfo &AL,
 
   // Create a record decl for the root signature.
   IdentifierInfo *II = &Context.Idents.get(FD->getName().str() + ".RS");
-  RecordDecl *RD =
-      RecordDecl::Create(Context, TagDecl::TagKind::Struct, DC,
-                         SourceLocation(), SourceLocation(), II);
+  RecordDecl *RD = RecordDecl::Create(Context, TagDecl::TagKind::Struct, DC,
+                                      SourceLocation(), SourceLocation(), II);
   // TODO: Add fields to the record decl.
 
   // Create a type for the root signature.
   QualType T = Context.getRecordType(RD);
   // Create a variable decl for the root signature.
-  VarDecl *VD = VarDecl::Create(Context, DC, SourceLocation(),
-                                SourceLocation(), II, T, nullptr, SC_None);
+  VarDecl *VD = VarDecl::Create(Context, DC, SourceLocation(), SourceLocation(),
+                                II, T, nullptr, SC_None);
 
   // TODO: Add initializers to the variable decl.
 



More information about the cfe-commits mailing list