[clang] 5fd4f32 - [HLSL] Implement SV_GroupID semantic (#115911)

via cfe-commits cfe-commits at lists.llvm.org
Tue Nov 26 10:45:35 PST 2024


Author: Zhengxing li
Date: 2024-11-26T10:45:31-08:00
New Revision: 5fd4f32f985f83414d82a1c2c55741e363693352

URL: https://github.com/llvm/llvm-project/commit/5fd4f32f985f83414d82a1c2c55741e363693352
DIFF: https://github.com/llvm/llvm-project/commit/5fd4f32f985f83414d82a1c2c55741e363693352.diff

LOG: [HLSL] Implement SV_GroupID semantic (#115911)

Support SV_GroupID attribute.
Translate it into dx.group.id in clang codeGen.

Fixes: #70120

Added: 
    clang/test/CodeGenHLSL/semantics/SV_GroupID.hlsl

Modified: 
    clang/include/clang/Basic/Attr.td
    clang/include/clang/Basic/AttrDocs.td
    clang/include/clang/Sema/SemaHLSL.h
    clang/lib/CodeGen/CGHLSLRuntime.cpp
    clang/lib/Parse/ParseHLSL.cpp
    clang/lib/Sema/SemaDeclAttr.cpp
    clang/lib/Sema/SemaHLSL.cpp
    clang/test/SemaHLSL/Semantics/entry_parameter.hlsl
    clang/test/SemaHLSL/Semantics/invalid_entry_parameter.hlsl
    clang/test/SemaHLSL/Semantics/valid_entry_parameter.hlsl

Removed: 
    


################################################################################
diff  --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index 6db36a015acfd7..b055cbd769bb50 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -4621,6 +4621,13 @@ def HLSLNumThreads: InheritableAttr {
   let Documentation = [NumThreadsDocs];
 }
 
+def HLSLSV_GroupID: HLSLAnnotationAttr {
+  let Spellings = [HLSLAnnotation<"SV_GroupID">];
+  let Subjects = SubjectList<[ParmVar, Field]>;
+  let LangOpts = [HLSL];
+  let Documentation = [HLSLSV_GroupIDDocs];
+}
+
 def HLSLSV_GroupIndex: HLSLAnnotationAttr {
   let Spellings = [HLSLAnnotation<"SV_GroupIndex">];
   let Subjects = SubjectList<[ParmVar, GlobalVar]>;

diff  --git a/clang/include/clang/Basic/AttrDocs.td b/clang/include/clang/Basic/AttrDocs.td
index cbbfedeec46cee..aafd4449e47004 100644
--- a/clang/include/clang/Basic/AttrDocs.td
+++ b/clang/include/clang/Basic/AttrDocs.td
@@ -7934,6 +7934,16 @@ randomized.
   }];
 }
 
+def HLSLSV_GroupIDDocs : Documentation {
+  let Category = DocCatFunction;
+  let Content = [{
+The ``SV_GroupID`` semantic, when applied to an input parameter, specifies which
+thread group a shader is executing in. This attribute is only supported in compute shaders.
+
+The full documentation is available here: https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/sv-groupid
+  }];
+}
+
 def HLSLSV_GroupIndexDocs : Documentation {
   let Category = DocCatFunction;
   let Content = [{

diff  --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h
index 06c541dec08cc8..ee685d95c96154 100644
--- a/clang/include/clang/Sema/SemaHLSL.h
+++ b/clang/include/clang/Sema/SemaHLSL.h
@@ -119,6 +119,7 @@ class SemaHLSL : public SemaBase {
   void handleNumThreadsAttr(Decl *D, const ParsedAttr &AL);
   void handleWaveSizeAttr(Decl *D, const ParsedAttr &AL);
   void handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL);
+  void handleSV_GroupIDAttr(Decl *D, const ParsedAttr &AL);
   void handlePackOffsetAttr(Decl *D, const ParsedAttr &AL);
   void handleShaderAttr(Decl *D, const ParsedAttr &AL);
   void handleResourceBindingAttr(Decl *D, const ParsedAttr &AL);
@@ -136,6 +137,9 @@ class SemaHLSL : public SemaBase {
 
   bool CheckCompatibleParameterABI(FunctionDecl *New, FunctionDecl *Old);
 
+  // Diagnose whether the input ID is uint/uint2/uint3 type.
+  bool diagnoseInputIDType(QualType T, const ParsedAttr &AL);
+
   ExprResult ActOnOutParamExpr(ParmVarDecl *Param, Expr *Arg);
 
   QualType getInoutParameterType(QualType Ty);

diff  --git a/clang/lib/CodeGen/CGHLSLRuntime.cpp b/clang/lib/CodeGen/CGHLSLRuntime.cpp
index 7ba0d615018181..2c293523fca8ca 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.cpp
+++ b/clang/lib/CodeGen/CGHLSLRuntime.cpp
@@ -389,6 +389,10 @@ llvm::Value *CGHLSLRuntime::emitInputSemantic(IRBuilder<> &B,
         CGM.getIntrinsic(getThreadIdIntrinsic());
     return buildVectorInput(B, ThreadIDIntrinsic, Ty);
   }
+  if (D.hasAttr<HLSLSV_GroupIDAttr>()) {
+    llvm::Function *GroupIDIntrinsic = CGM.getIntrinsic(Intrinsic::dx_group_id);
+    return buildVectorInput(B, GroupIDIntrinsic, Ty);
+  }
   assert(false && "Unhandled parameter attribute");
   return nullptr;
 }

diff  --git a/clang/lib/Parse/ParseHLSL.cpp b/clang/lib/Parse/ParseHLSL.cpp
index 46a37e94353533..4de342b63ed802 100644
--- a/clang/lib/Parse/ParseHLSL.cpp
+++ b/clang/lib/Parse/ParseHLSL.cpp
@@ -280,6 +280,7 @@ void Parser::ParseHLSLAnnotations(ParsedAttributes &Attrs,
   case ParsedAttr::UnknownAttribute:
     Diag(Loc, diag::err_unknown_hlsl_semantic) << II;
     return;
+  case ParsedAttr::AT_HLSLSV_GroupID:
   case ParsedAttr::AT_HLSLSV_GroupIndex:
   case ParsedAttr::AT_HLSLSV_DispatchThreadID:
     break;

diff  --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
index 146d9c86e0715a..53cc8cb6afd7dc 100644
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -7103,6 +7103,9 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL,
   case ParsedAttr::AT_HLSLWaveSize:
     S.HLSL().handleWaveSizeAttr(D, AL);
     break;
+  case ParsedAttr::AT_HLSLSV_GroupID:
+    S.HLSL().handleSV_GroupIDAttr(D, AL);
+    break;
   case ParsedAttr::AT_HLSLSV_GroupIndex:
     handleSimpleAttribute<HLSLSV_GroupIndexAttr>(S, D, AL);
     break;

diff  --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 8109c3a2cc0f1b..8b2f24a8e4be0a 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -434,6 +434,7 @@ void SemaHLSL::CheckSemanticAnnotation(
   switch (AnnotationAttr->getKind()) {
   case attr::HLSLSV_DispatchThreadID:
   case attr::HLSLSV_GroupIndex:
+  case attr::HLSLSV_GroupID:
     if (ST == llvm::Triple::Compute)
       return;
     DiagnoseAttrStageMismatch(AnnotationAttr, ST, {llvm::Triple::Compute});
@@ -764,26 +765,36 @@ void SemaHLSL::handleWaveSizeAttr(Decl *D, const ParsedAttr &AL) {
     D->addAttr(NewAttr);
 }
 
-static bool isLegalTypeForHLSLSV_DispatchThreadID(QualType T) {
-  if (!T->hasUnsignedIntegerRepresentation())
+bool SemaHLSL::diagnoseInputIDType(QualType T, const ParsedAttr &AL) {
+  const auto *VT = T->getAs<VectorType>();
+
+  if (!T->hasUnsignedIntegerRepresentation() ||
+      (VT && VT->getNumElements() > 3)) {
+    Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_type)
+        << AL << "uint/uint2/uint3";
     return false;
-  if (const auto *VT = T->getAs<VectorType>())
-    return VT->getNumElements() <= 3;
+  }
+
   return true;
 }
 
 void SemaHLSL::handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL) {
   auto *VD = cast<ValueDecl>(D);
-  if (!isLegalTypeForHLSLSV_DispatchThreadID(VD->getType())) {
-    Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_type)
-        << AL << "uint/uint2/uint3";
+  if (!diagnoseInputIDType(VD->getType(), AL))
     return;
-  }
 
   D->addAttr(::new (getASTContext())
                  HLSLSV_DispatchThreadIDAttr(getASTContext(), AL));
 }
 
+void SemaHLSL::handleSV_GroupIDAttr(Decl *D, const ParsedAttr &AL) {
+  auto *VD = cast<ValueDecl>(D);
+  if (!diagnoseInputIDType(VD->getType(), AL))
+    return;
+
+  D->addAttr(::new (getASTContext()) HLSLSV_GroupIDAttr(getASTContext(), AL));
+}
+
 void SemaHLSL::handlePackOffsetAttr(Decl *D, const ParsedAttr &AL) {
   if (!isa<VarDecl>(D) || !isa<HLSLBufferDecl>(D->getDeclContext())) {
     Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_ast_node)

diff  --git a/clang/test/CodeGenHLSL/semantics/SV_GroupID.hlsl b/clang/test/CodeGenHLSL/semantics/SV_GroupID.hlsl
new file mode 100644
index 00000000000000..5e09f0fe06d4e6
--- /dev/null
+++ b/clang/test/CodeGenHLSL/semantics/SV_GroupID.hlsl
@@ -0,0 +1,32 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -finclude-default-header -disable-llvm-passes -o - %s | FileCheck %s
+
+// Make sure SV_GroupID is translated into dx.group.id.
+
+// CHECK:  define void @foo()
+// CHECK:  %[[#ID:]] = call i32 @llvm.dx.group.id(i32 0)
+// CHECK:  call void @{{.*}}foo{{.*}}(i32 %[[#ID]])
+[shader("compute")]
+[numthreads(8,8,1)]
+void foo(uint Idx : SV_GroupID) {}
+
+// CHECK:  define void @bar()
+// CHECK:  %[[#ID_X:]] = call i32 @llvm.dx.group.id(i32 0)
+// CHECK:  %[[#ID_X_:]] = insertelement <2 x i32> poison, i32 %[[#ID_X]], i64 0
+// CHECK:  %[[#ID_Y:]] = call i32 @llvm.dx.group.id(i32 1)
+// CHECK:  %[[#ID_XY:]] = insertelement <2 x i32> %[[#ID_X_]], i32 %[[#ID_Y]], i64 1
+// CHECK:  call void @{{.*}}bar{{.*}}(<2 x i32> %[[#ID_XY]])
+[shader("compute")]
+[numthreads(8,8,1)]
+void bar(uint2 Idx : SV_GroupID) {}
+
+// CHECK:  define void @test()
+// CHECK:  %[[#ID_X:]] = call i32 @llvm.dx.group.id(i32 0)
+// CHECK:  %[[#ID_X_:]] = insertelement <3 x i32> poison, i32 %[[#ID_X]], i64 0
+// CHECK:  %[[#ID_Y:]] = call i32 @llvm.dx.group.id(i32 1)
+// CHECK:  %[[#ID_XY:]] = insertelement <3 x i32> %[[#ID_X_]], i32 %[[#ID_Y]], i64 1
+// CHECK:  %[[#ID_Z:]] = call i32 @llvm.dx.group.id(i32 2)
+// CHECK:  %[[#ID_XYZ:]] = insertelement <3 x i32> %[[#ID_XY]], i32 %[[#ID_Z]], i64 2
+// CHECK:  call void @{{.*}}test{{.*}}(<3 x i32> %[[#ID_XYZ]])
+[shader("compute")]
+[numthreads(8,8,1)]
+void test(uint3 Idx : SV_GroupID) {}

diff  --git a/clang/test/SemaHLSL/Semantics/entry_parameter.hlsl b/clang/test/SemaHLSL/Semantics/entry_parameter.hlsl
index 8484259f84692b..13c07038d2e4a4 100644
--- a/clang/test/SemaHLSL/Semantics/entry_parameter.hlsl
+++ b/clang/test/SemaHLSL/Semantics/entry_parameter.hlsl
@@ -2,12 +2,15 @@
 // RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-mesh -hlsl-entry CSMain -x hlsl -finclude-default-header  -verify -o - %s
 
 [numthreads(8,8,1)]
-// expected-error@+2 {{attribute 'SV_GroupIndex' is unsupported in 'mesh' shaders, requires compute}}
-// expected-error@+1 {{attribute 'SV_DispatchThreadID' is unsupported in 'mesh' shaders, requires compute}}
-void CSMain(int GI : SV_GroupIndex, uint ID : SV_DispatchThreadID) {
-// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain 'void (int, uint)'
+// expected-error@+3 {{attribute 'SV_GroupIndex' is unsupported in 'mesh' shaders, requires compute}}
+// expected-error@+2 {{attribute 'SV_DispatchThreadID' is unsupported in 'mesh' shaders, requires compute}}
+// expected-error@+1 {{attribute 'SV_GroupID' is unsupported in 'mesh' shaders, requires compute}}
+void CSMain(int GI : SV_GroupIndex, uint ID : SV_DispatchThreadID, uint GID : SV_GroupID) {
+// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain 'void (int, uint, uint)'
 // CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:17 GI 'int'
 // CHECK-NEXT: HLSLSV_GroupIndexAttr
 // CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:42 ID 'uint'
 // CHECK-NEXT: HLSLSV_DispatchThreadIDAttr
+// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:73 GID 'uint'
+// CHECK-NEXT: HLSLSV_GroupIDAttr
 }

diff  --git a/clang/test/SemaHLSL/Semantics/invalid_entry_parameter.hlsl b/clang/test/SemaHLSL/Semantics/invalid_entry_parameter.hlsl
index bc3cf8bc51daf4..4e1f88aa2294b5 100644
--- a/clang/test/SemaHLSL/Semantics/invalid_entry_parameter.hlsl
+++ b/clang/test/SemaHLSL/Semantics/invalid_entry_parameter.hlsl
@@ -27,3 +27,25 @@ struct ST2 {
     static uint X : SV_DispatchThreadID;
     uint s : SV_DispatchThreadID;
 };
+
+[numthreads(8,8,1)]
+// expected-error@+1 {{attribute 'SV_GroupID' only applies to a field or parameter of type 'uint/uint2/uint3'}}
+void CSMain_GID(float ID : SV_GroupID) {
+}
+
+[numthreads(8,8,1)]
+// expected-error@+1 {{attribute 'SV_GroupID' only applies to a field or parameter of type 'uint/uint2/uint3'}}
+void CSMain2_GID(ST GID : SV_GroupID) {
+
+}
+
+void foo_GID() {
+// expected-warning@+1 {{'SV_GroupID' attribute only applies to parameters and non-static data members}}
+  uint GIS : SV_GroupID;
+}
+
+struct ST2_GID {
+// expected-warning@+1 {{'SV_GroupID' attribute only applies to parameters and non-static data members}}
+    static uint GID : SV_GroupID;
+    uint s_gid : SV_GroupID;
+};

diff  --git a/clang/test/SemaHLSL/Semantics/valid_entry_parameter.hlsl b/clang/test/SemaHLSL/Semantics/valid_entry_parameter.hlsl
index 8e79fc4d85ec91..10a5e5dabac87b 100644
--- a/clang/test/SemaHLSL/Semantics/valid_entry_parameter.hlsl
+++ b/clang/test/SemaHLSL/Semantics/valid_entry_parameter.hlsl
@@ -24,3 +24,28 @@ void CSMain3(uint3 : SV_DispatchThreadID) {
 // CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:20 'uint3'
 // CHECK-NEXT: HLSLSV_DispatchThreadIDAttr
 }
+
+[numthreads(8,8,1)]
+void CSMain_GID(uint ID : SV_GroupID) {
+// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain_GID 'void (uint)'
+// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:22 ID 'uint'
+// CHECK-NEXT: HLSLSV_GroupIDAttr
+}
+[numthreads(8,8,1)]
+void CSMain1_GID(uint2 ID : SV_GroupID) {
+// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain1_GID 'void (uint2)'
+// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:24 ID 'uint2'
+// CHECK-NEXT: HLSLSV_GroupIDAttr
+}
+[numthreads(8,8,1)]
+void CSMain2_GID(uint3 ID : SV_GroupID) {
+// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain2_GID 'void (uint3)'
+// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:24 ID 'uint3'
+// CHECK-NEXT: HLSLSV_GroupIDAttr
+}
+[numthreads(8,8,1)]
+void CSMain3_GID(uint3 : SV_GroupID) {
+// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain3_GID 'void (uint3)'
+// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:24 'uint3'
+// CHECK-NEXT: HLSLSV_GroupIDAttr
+}


        


More information about the cfe-commits mailing list