[clang] [llvm] [HLSL] Implement SV_GroupThreadId semantic (PR #117781)

Zhengxing li via cfe-commits cfe-commits at lists.llvm.org
Fri Dec 6 09:41:18 PST 2024


https://github.com/lizhengxing updated https://github.com/llvm/llvm-project/pull/117781

>From 2941d87dbaf091aa443ad57ce55e98e7bab83d2b Mon Sep 17 00:00:00 2001
From: Zhengxing Li <zhengxingli at microsoft.com>
Date: Wed, 13 Nov 2024 10:54:16 -0800
Subject: [PATCH 1/3] [HLSL] Implement SV_GroupThreadId semantic

Support the SV_GroupThreadID semantic attribute.
Translate it into the dx.thread.id.in.group intrinsic in Clang CodeGen.

Fixes: #70122
---
 clang/include/clang/Basic/Attr.td             |  7 ++++
 clang/include/clang/Basic/AttrDocs.td         | 11 +++++++
 clang/include/clang/Sema/SemaHLSL.h           |  1 +
 clang/lib/CodeGen/CGHLSLRuntime.cpp           |  5 +++
 clang/lib/Parse/ParseHLSL.cpp                 |  1 +
 clang/lib/Sema/SemaDeclAttr.cpp               |  3 ++
 clang/lib/Sema/SemaHLSL.cpp                   | 10 ++++++
 .../semantics/SV_GroupThreadID.hlsl           | 32 +++++++++++++++++++
 .../SemaHLSL/Semantics/entry_parameter.hlsl   | 13 +++++---
 .../Semantics/invalid_entry_parameter.hlsl    | 22 +++++++++++++
 .../Semantics/valid_entry_parameter.hlsl      | 25 +++++++++++++++
 11 files changed, 125 insertions(+), 5 deletions(-)
 create mode 100644 clang/test/CodeGenHLSL/semantics/SV_GroupThreadID.hlsl

diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index 17fc36fbe2ac8c..90d2a2056fe1ba 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -4651,6 +4651,13 @@ def HLSLNumThreads: InheritableAttr {
   let Documentation = [NumThreadsDocs];
 }
 
+def HLSLSV_GroupThreadID: HLSLAnnotationAttr {
+  let Spellings = [HLSLAnnotation<"SV_GroupThreadID">];
+  let Subjects = SubjectList<[ParmVar, Field]>;
+  let LangOpts = [HLSL];
+  let Documentation = [HLSLSV_GroupThreadIDDocs];
+}
+
 def HLSLSV_GroupID: HLSLAnnotationAttr {
   let Spellings = [HLSLAnnotation<"SV_GroupID">];
   let Subjects = SubjectList<[ParmVar, Field]>;
diff --git a/clang/include/clang/Basic/AttrDocs.td b/clang/include/clang/Basic/AttrDocs.td
index 7a82b8fa320590..fdad4c9a3ea191 100644
--- a/clang/include/clang/Basic/AttrDocs.td
+++ b/clang/include/clang/Basic/AttrDocs.td
@@ -7941,6 +7941,17 @@ randomized.
   }];
 }
 
+def HLSLSV_GroupThreadIDDocs : Documentation {
+  let Category = DocCatFunction;
+  let Content = [{
+The ``SV_GroupThreadID`` semantic, when applied to an input parameter, specifies the
+ID of the individual thread, within its thread group, that is executing. This attribute is
+only supported in compute shaders.
+
+The full documentation is available here: https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/sv-groupthreadid
+  }];
+}
+
 def HLSLSV_GroupIDDocs : Documentation {
   let Category = DocCatFunction;
   let Content = [{
diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h
index ee685d95c96154..f4cd11f423a84a 100644
--- a/clang/include/clang/Sema/SemaHLSL.h
+++ b/clang/include/clang/Sema/SemaHLSL.h
@@ -119,6 +119,7 @@ class SemaHLSL : public SemaBase {
   void handleNumThreadsAttr(Decl *D, const ParsedAttr &AL);
   void handleWaveSizeAttr(Decl *D, const ParsedAttr &AL);
   void handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL);
+  void handleSV_GroupThreadIDAttr(Decl *D, const ParsedAttr &AL);
   void handleSV_GroupIDAttr(Decl *D, const ParsedAttr &AL);
   void handlePackOffsetAttr(Decl *D, const ParsedAttr &AL);
   void handleShaderAttr(Decl *D, const ParsedAttr &AL);
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.cpp b/clang/lib/CodeGen/CGHLSLRuntime.cpp
index 2c293523fca8ca..19db7faddaeac0 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.cpp
+++ b/clang/lib/CodeGen/CGHLSLRuntime.cpp
@@ -389,6 +389,11 @@ llvm::Value *CGHLSLRuntime::emitInputSemantic(IRBuilder<> &B,
         CGM.getIntrinsic(getThreadIdIntrinsic());
     return buildVectorInput(B, ThreadIDIntrinsic, Ty);
   }
+  if (D.hasAttr<HLSLSV_GroupThreadIDAttr>()) {
+    llvm::Function *GroupThreadIDIntrinsic =
+        CGM.getIntrinsic(Intrinsic::dx_thread_id_in_group);
+    return buildVectorInput(B, GroupThreadIDIntrinsic, Ty);
+  }
   if (D.hasAttr<HLSLSV_GroupIDAttr>()) {
     llvm::Function *GroupIDIntrinsic = CGM.getIntrinsic(Intrinsic::dx_group_id);
     return buildVectorInput(B, GroupIDIntrinsic, Ty);
diff --git a/clang/lib/Parse/ParseHLSL.cpp b/clang/lib/Parse/ParseHLSL.cpp
index 4de342b63ed802..443bf2b9ec626a 100644
--- a/clang/lib/Parse/ParseHLSL.cpp
+++ b/clang/lib/Parse/ParseHLSL.cpp
@@ -280,6 +280,7 @@ void Parser::ParseHLSLAnnotations(ParsedAttributes &Attrs,
   case ParsedAttr::UnknownAttribute:
     Diag(Loc, diag::err_unknown_hlsl_semantic) << II;
     return;
+  case ParsedAttr::AT_HLSLSV_GroupThreadID:
   case ParsedAttr::AT_HLSLSV_GroupID:
   case ParsedAttr::AT_HLSLSV_GroupIndex:
   case ParsedAttr::AT_HLSLSV_DispatchThreadID:
diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
index 4fd8ef6dbebf84..5d7ee097383771 100644
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -7114,6 +7114,9 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL,
   case ParsedAttr::AT_HLSLWaveSize:
     S.HLSL().handleWaveSizeAttr(D, AL);
     break;
+  case ParsedAttr::AT_HLSLSV_GroupThreadID:
+    S.HLSL().handleSV_GroupThreadIDAttr(D, AL);
+    break;
   case ParsedAttr::AT_HLSLSV_GroupID:
     S.HLSL().handleSV_GroupIDAttr(D, AL);
     break;
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 88db3e12541193..600c800029fd05 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -434,6 +434,7 @@ void SemaHLSL::CheckSemanticAnnotation(
   switch (AnnotationAttr->getKind()) {
   case attr::HLSLSV_DispatchThreadID:
   case attr::HLSLSV_GroupIndex:
+  case attr::HLSLSV_GroupThreadID:
   case attr::HLSLSV_GroupID:
     if (ST == llvm::Triple::Compute)
       return;
@@ -787,6 +788,15 @@ void SemaHLSL::handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL) {
                  HLSLSV_DispatchThreadIDAttr(getASTContext(), AL));
 }
 
+void SemaHLSL::handleSV_GroupThreadIDAttr(Decl *D, const ParsedAttr &AL) {
+  auto *VD = cast<ValueDecl>(D);
+  if (!diagnoseInputIDType(VD->getType(), AL))
+    return;
+
+  D->addAttr(::new (getASTContext())
+                 HLSLSV_GroupThreadIDAttr(getASTContext(), AL));
+}
+
 void SemaHLSL::handleSV_GroupIDAttr(Decl *D, const ParsedAttr &AL) {
   auto *VD = cast<ValueDecl>(D);
   if (!diagnoseInputIDType(VD->getType(), AL))
diff --git a/clang/test/CodeGenHLSL/semantics/SV_GroupThreadID.hlsl b/clang/test/CodeGenHLSL/semantics/SV_GroupThreadID.hlsl
new file mode 100644
index 00000000000000..3533331c6f091c
--- /dev/null
+++ b/clang/test/CodeGenHLSL/semantics/SV_GroupThreadID.hlsl
@@ -0,0 +1,32 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -finclude-default-header -disable-llvm-passes -o - %s | FileCheck %s
+
+// Make sure SV_GroupThreadID translated into dx.thread.id.in.group.
+
+// CHECK:  define void @foo()
+// CHECK:  %[[#ID:]] = call i32 @llvm.dx.thread.id.in.group(i32 0)
+// CHECK:  call void @{{.*}}foo{{.*}}(i32 %[[#ID]])
+[shader("compute")]
+[numthreads(8,8,1)]
+void foo(uint Idx : SV_GroupThreadID) {}
+
+// CHECK:  define void @bar()
+// CHECK:  %[[#ID_X:]] = call i32 @llvm.dx.thread.id.in.group(i32 0)
+// CHECK:  %[[#ID_X_:]] = insertelement <2 x i32> poison, i32 %[[#ID_X]], i64 0
+// CHECK:  %[[#ID_Y:]] = call i32 @llvm.dx.thread.id.in.group(i32 1)
+// CHECK:  %[[#ID_XY:]] = insertelement <2 x i32> %[[#ID_X_]], i32 %[[#ID_Y]], i64 1
+// CHECK:  call void @{{.*}}bar{{.*}}(<2 x i32> %[[#ID_XY]])
+[shader("compute")]
+[numthreads(8,8,1)]
+void bar(uint2 Idx : SV_GroupThreadID) {}
+
+// CHECK:  define void @test()
+// CHECK:  %[[#ID_X:]] = call i32 @llvm.dx.thread.id.in.group(i32 0)
+// CHECK:  %[[#ID_X_:]] = insertelement <3 x i32> poison, i32 %[[#ID_X]], i64 0
+// CHECK:  %[[#ID_Y:]] = call i32 @llvm.dx.thread.id.in.group(i32 1)
+// CHECK:  %[[#ID_XY:]] = insertelement <3 x i32> %[[#ID_X_]], i32 %[[#ID_Y]], i64 1
+// CHECK:  %[[#ID_Z:]] = call i32 @llvm.dx.thread.id.in.group(i32 2)
+// CHECK:  %[[#ID_XYZ:]] = insertelement <3 x i32> %[[#ID_XY]], i32 %[[#ID_Z]], i64 2
+// CHECK:  call void @{{.*}}test{{.*}}(<3 x i32> %[[#ID_XYZ]])
+[shader("compute")]
+[numthreads(8,8,1)]
+void test(uint3 Idx : SV_GroupThreadID) {}
diff --git a/clang/test/SemaHLSL/Semantics/entry_parameter.hlsl b/clang/test/SemaHLSL/Semantics/entry_parameter.hlsl
index 13c07038d2e4a4..71d32cd13832e1 100644
--- a/clang/test/SemaHLSL/Semantics/entry_parameter.hlsl
+++ b/clang/test/SemaHLSL/Semantics/entry_parameter.hlsl
@@ -2,15 +2,18 @@
 // RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-mesh -hlsl-entry CSMain -x hlsl -finclude-default-header  -verify -o - %s
 
 [numthreads(8,8,1)]
-// expected-error at +3 {{attribute 'SV_GroupIndex' is unsupported in 'mesh' shaders, requires compute}}
-// expected-error at +2 {{attribute 'SV_DispatchThreadID' is unsupported in 'mesh' shaders, requires compute}}
-// expected-error at +1 {{attribute 'SV_GroupID' is unsupported in 'mesh' shaders, requires compute}}
-void CSMain(int GI : SV_GroupIndex, uint ID : SV_DispatchThreadID, uint GID : SV_GroupID) {
-// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain 'void (int, uint, uint)'
+// expected-error at +4 {{attribute 'SV_GroupIndex' is unsupported in 'mesh' shaders, requires compute}}
+// expected-error at +3 {{attribute 'SV_DispatchThreadID' is unsupported in 'mesh' shaders, requires compute}}
+// expected-error at +2 {{attribute 'SV_GroupID' is unsupported in 'mesh' shaders, requires compute}}
+// expected-error at +1 {{attribute 'SV_GroupThreadID' is unsupported in 'mesh' shaders, requires compute}}
+void CSMain(int GI : SV_GroupIndex, uint ID : SV_DispatchThreadID, uint GID : SV_GroupID, uint GThreadID : SV_GroupThreadID) {
+// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain 'void (int, uint, uint, uint)'
 // CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:17 GI 'int'
 // CHECK-NEXT: HLSLSV_GroupIndexAttr
 // CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:42 ID 'uint'
 // CHECK-NEXT: HLSLSV_DispatchThreadIDAttr
 // CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:73 GID 'uint'
 // CHECK-NEXT: HLSLSV_GroupIDAttr
+// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:96 GThreadID 'uint'
+// CHECK-NEXT: HLSLSV_GroupThreadIDAttr
 }
diff --git a/clang/test/SemaHLSL/Semantics/invalid_entry_parameter.hlsl b/clang/test/SemaHLSL/Semantics/invalid_entry_parameter.hlsl
index 4e1f88aa2294b5..a24112c8e1bb8f 100644
--- a/clang/test/SemaHLSL/Semantics/invalid_entry_parameter.hlsl
+++ b/clang/test/SemaHLSL/Semantics/invalid_entry_parameter.hlsl
@@ -49,3 +49,25 @@ struct ST2_GID {
     static uint GID : SV_GroupID;
     uint s_gid : SV_GroupID;
 };
+
+[numthreads(8,8,1)]
+// expected-error at +1 {{attribute 'SV_GroupThreadID' only applies to a field or parameter of type 'uint/uint2/uint3'}}
+void CSMain_GThreadID(float ID : SV_GroupThreadID) {
+}
+
+[numthreads(8,8,1)]
+// expected-error at +1 {{attribute 'SV_GroupThreadID' only applies to a field or parameter of type 'uint/uint2/uint3'}}
+void CSMain2_GThreadID(ST GID : SV_GroupThreadID) {
+
+}
+
+void foo_GThreadID() {
+// expected-warning at +1 {{'SV_GroupThreadID' attribute only applies to parameters and non-static data members}}
+  uint GThreadIS : SV_GroupThreadID;
+}
+
+struct ST2_GThreadID {
+// expected-warning at +1 {{'SV_GroupThreadID' attribute only applies to parameters and non-static data members}}
+    static uint GThreadID : SV_GroupThreadID;
+    uint s_gthreadid : SV_GroupThreadID;
+};
diff --git a/clang/test/SemaHLSL/Semantics/valid_entry_parameter.hlsl b/clang/test/SemaHLSL/Semantics/valid_entry_parameter.hlsl
index 10a5e5dabac87b..6781f9241df240 100644
--- a/clang/test/SemaHLSL/Semantics/valid_entry_parameter.hlsl
+++ b/clang/test/SemaHLSL/Semantics/valid_entry_parameter.hlsl
@@ -49,3 +49,28 @@ void CSMain3_GID(uint3 : SV_GroupID) {
 // CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:24 'uint3'
 // CHECK-NEXT: HLSLSV_GroupIDAttr
 }
+
+[numthreads(8,8,1)]
+void CSMain_GThreadID(uint ID : SV_GroupThreadID) {
+// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain_GThreadID 'void (uint)'
+// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:28 ID 'uint'
+// CHECK-NEXT: HLSLSV_GroupThreadIDAttr
+}
+[numthreads(8,8,1)]
+void CSMain1_GThreadID(uint2 ID : SV_GroupThreadID) {
+// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain1_GThreadID 'void (uint2)'
+// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:30 ID 'uint2'
+// CHECK-NEXT: HLSLSV_GroupThreadIDAttr
+}
+[numthreads(8,8,1)]
+void CSMain2_GThreadID(uint3 ID : SV_GroupThreadID) {
+// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain2_GThreadID 'void (uint3)'
+// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:30 ID 'uint3'
+// CHECK-NEXT: HLSLSV_GroupThreadIDAttr
+}
+[numthreads(8,8,1)]
+void CSMain3_GThreadID(uint3 : SV_GroupThreadID) {
+// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain3_GThreadID 'void (uint3)'
+// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:30 'uint3'
+// CHECK-NEXT: HLSLSV_GroupThreadIDAttr
+}

>From dc8d779f067e5b8d22e56036c4ba7320e297f339 Mon Sep 17 00:00:00 2001
From: Zhengxing Li <zhengxingli at microsoft.com>
Date: Thu, 5 Dec 2024 10:54:54 -0800
Subject: [PATCH 2/3] Don't test the Group/Thread input IDs with mesh shader

The SV_GroupIndex, SV_DispatchThreadID, SV_GroupID and SV_GroupThreadID semantics are actually legal for the mesh shader stage, so they shouldn't be tested as invalid with a mesh shader.

This commit tests them with a vertex shader instead and moves the test into invalid_entry_parameter.hlsl, which is a better place for it.
---
 clang/test/SemaHLSL/Semantics/entry_parameter.hlsl        | 5 -----
 .../test/SemaHLSL/Semantics/invalid_entry_parameter.hlsl  | 8 ++++++++
 2 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/clang/test/SemaHLSL/Semantics/entry_parameter.hlsl b/clang/test/SemaHLSL/Semantics/entry_parameter.hlsl
index 71d32cd13832e1..393d7300605c09 100644
--- a/clang/test/SemaHLSL/Semantics/entry_parameter.hlsl
+++ b/clang/test/SemaHLSL/Semantics/entry_parameter.hlsl
@@ -1,11 +1,6 @@
 // RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -hlsl-entry CSMain -x hlsl  -finclude-default-header  -ast-dump -o - %s | FileCheck %s
-// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-mesh -hlsl-entry CSMain -x hlsl -finclude-default-header  -verify -o - %s
 
 [numthreads(8,8,1)]
-// expected-error at +4 {{attribute 'SV_GroupIndex' is unsupported in 'mesh' shaders, requires compute}}
-// expected-error at +3 {{attribute 'SV_DispatchThreadID' is unsupported in 'mesh' shaders, requires compute}}
-// expected-error at +2 {{attribute 'SV_GroupID' is unsupported in 'mesh' shaders, requires compute}}
-// expected-error at +1 {{attribute 'SV_GroupThreadID' is unsupported in 'mesh' shaders, requires compute}}
 void CSMain(int GI : SV_GroupIndex, uint ID : SV_DispatchThreadID, uint GID : SV_GroupID, uint GThreadID : SV_GroupThreadID) {
 // CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain 'void (int, uint, uint, uint)'
 // CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:17 GI 'int'
diff --git a/clang/test/SemaHLSL/Semantics/invalid_entry_parameter.hlsl b/clang/test/SemaHLSL/Semantics/invalid_entry_parameter.hlsl
index a24112c8e1bb8f..1bb4ee5182d621 100644
--- a/clang/test/SemaHLSL/Semantics/invalid_entry_parameter.hlsl
+++ b/clang/test/SemaHLSL/Semantics/invalid_entry_parameter.hlsl
@@ -71,3 +71,11 @@ struct ST2_GThreadID {
     static uint GThreadID : SV_GroupThreadID;
     uint s_gthreadid : SV_GroupThreadID;
 };
+
+
+[shader("vertex")]
+// expected-error at +4 {{attribute 'SV_GroupIndex' is unsupported in 'vertex' shaders, requires compute}}
+// expected-error at +3 {{attribute 'SV_DispatchThreadID' is unsupported in 'vertex' shaders, requires compute}}
+// expected-error at +2 {{attribute 'SV_GroupID' is unsupported in 'vertex' shaders, requires compute}}
+// expected-error at +1 {{attribute 'SV_GroupThreadID' is unsupported in 'vertex' shaders, requires compute}}
+void vs_main(int GI : SV_GroupIndex, uint ID : SV_DispatchThreadID, uint GID : SV_GroupID, uint GThreadID : SV_GroupThreadID) {}

>From 28f823454873d4bc029f2ec57bed3a9707bbf1b2 Mon Sep 17 00:00:00 2001
From: Zhengxing Li <zhengxingli at microsoft.com>
Date: Thu, 5 Dec 2024 15:06:01 -0800
Subject: [PATCH 3/3] [HLSL][SPIR-V] Add SV_GroupThreadID semantic support

The HLSL SV_GroupThreadID semantic attribute is lowered into the @llvm.spv.thread.id.in.group intrinsic in LLVM IR for the SPIR-V target.

In the SPIR-V backend, this is now correctly translated to a `LocalInvocationId` builtin variable.

Fixes #70122
---
 clang/lib/CodeGen/CGHLSLRuntime.cpp           |  2 +-
 clang/lib/CodeGen/CGHLSLRuntime.h             |  1 +
 .../semantics/SV_GroupThreadID.hlsl           | 34 +++++----
 llvm/include/llvm/IR/IntrinsicsSPIRV.td       |  1 +
 .../Target/SPIRV/SPIRVInstructionSelector.cpp | 59 ++++++++++----
 .../SPIRV/hlsl-intrinsics/SV_GroupThreadID.ll | 76 +++++++++++++++++++
 6 files changed, 144 insertions(+), 29 deletions(-)
 create mode 100644 llvm/test/CodeGen/SPIRV/hlsl-intrinsics/SV_GroupThreadID.ll

diff --git a/clang/lib/CodeGen/CGHLSLRuntime.cpp b/clang/lib/CodeGen/CGHLSLRuntime.cpp
index 19db7faddaeac0..fb15b1993e74ad 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.cpp
+++ b/clang/lib/CodeGen/CGHLSLRuntime.cpp
@@ -391,7 +391,7 @@ llvm::Value *CGHLSLRuntime::emitInputSemantic(IRBuilder<> &B,
   }
   if (D.hasAttr<HLSLSV_GroupThreadIDAttr>()) {
     llvm::Function *GroupThreadIDIntrinsic =
-        CGM.getIntrinsic(Intrinsic::dx_thread_id_in_group);
+        CGM.getIntrinsic(getGroupThreadIdIntrinsic());
     return buildVectorInput(B, GroupThreadIDIntrinsic, Ty);
   }
   if (D.hasAttr<HLSLSV_GroupIDAttr>()) {
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index bb120c8b5e9e60..f9efb1bc996412 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -86,6 +86,7 @@ class CGHLSLRuntime {
   GENERATE_HLSL_INTRINSIC_FUNCTION(Step, step)
   GENERATE_HLSL_INTRINSIC_FUNCTION(Radians, radians)
   GENERATE_HLSL_INTRINSIC_FUNCTION(ThreadId, thread_id)
+  GENERATE_HLSL_INTRINSIC_FUNCTION(GroupThreadId, thread_id_in_group)
   GENERATE_HLSL_INTRINSIC_FUNCTION(FDot, fdot)
   GENERATE_HLSL_INTRINSIC_FUNCTION(SDot, sdot)
   GENERATE_HLSL_INTRINSIC_FUNCTION(UDot, udot)
diff --git a/clang/test/CodeGenHLSL/semantics/SV_GroupThreadID.hlsl b/clang/test/CodeGenHLSL/semantics/SV_GroupThreadID.hlsl
index 3533331c6f091c..3d347b973f39c8 100644
--- a/clang/test/CodeGenHLSL/semantics/SV_GroupThreadID.hlsl
+++ b/clang/test/CodeGenHLSL/semantics/SV_GroupThreadID.hlsl
@@ -1,32 +1,36 @@
-// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -finclude-default-header -disable-llvm-passes -o - %s | FileCheck %s
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -finclude-default-header -disable-llvm-passes -o - %s | FileCheck %s --check-prefixes=CHECK,CHECK-DXIL -DTARGET=dx
+// RUN: %clang_cc1 -triple spirv-linux-vulkan-library -x hlsl -emit-llvm -finclude-default-header -disable-llvm-passes -o - %s | FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV -DTARGET=spv
 
-// Make sure SV_GroupThreadID translated into dx.thread.id.in.group.
+// Make sure SV_GroupThreadID is translated into dx.thread.id.in.group for the DirectX target and spv.thread.id.in.group for the SPIR-V target.
 
-// CHECK:  define void @foo()
-// CHECK:  %[[#ID:]] = call i32 @llvm.dx.thread.id.in.group(i32 0)
-// CHECK:  call void @{{.*}}foo{{.*}}(i32 %[[#ID]])
+// CHECK:       define void @foo()
+// CHECK:       %[[#ID:]] = call i32 @llvm.[[TARGET]].thread.id.in.group(i32 0)
+// CHECK-DXIL:       call void @{{.*}}foo{{.*}}(i32 %[[#ID]])
+// CHECK-SPIRV:      call spir_func void @{{.*}}foo{{.*}}(i32 %[[#ID]])
 [shader("compute")]
 [numthreads(8,8,1)]
 void foo(uint Idx : SV_GroupThreadID) {}
 
-// CHECK:  define void @bar()
-// CHECK:  %[[#ID_X:]] = call i32 @llvm.dx.thread.id.in.group(i32 0)
-// CHECK:  %[[#ID_X_:]] = insertelement <2 x i32> poison, i32 %[[#ID_X]], i64 0
-// CHECK:  %[[#ID_Y:]] = call i32 @llvm.dx.thread.id.in.group(i32 1)
-// CHECK:  %[[#ID_XY:]] = insertelement <2 x i32> %[[#ID_X_]], i32 %[[#ID_Y]], i64 1
-// CHECK:  call void @{{.*}}bar{{.*}}(<2 x i32> %[[#ID_XY]])
+// CHECK:       define void @bar()
+// CHECK:       %[[#ID_X:]] = call i32 @llvm.[[TARGET]].thread.id.in.group(i32 0)
+// CHECK:       %[[#ID_X_:]] = insertelement <2 x i32> poison, i32 %[[#ID_X]], i64 0
+// CHECK:       %[[#ID_Y:]] = call i32 @llvm.[[TARGET]].thread.id.in.group(i32 1)
+// CHECK:       %[[#ID_XY:]] = insertelement <2 x i32> %[[#ID_X_]], i32 %[[#ID_Y]], i64 1
+// CHECK-DXIL:  call void @{{.*}}bar{{.*}}(<2 x i32> %[[#ID_XY]])
+// CHECK-SPIRV:  call spir_func void @{{.*}}bar{{.*}}(<2 x i32> %[[#ID_XY]])
 [shader("compute")]
 [numthreads(8,8,1)]
 void bar(uint2 Idx : SV_GroupThreadID) {}
 
 // CHECK:  define void @test()
-// CHECK:  %[[#ID_X:]] = call i32 @llvm.dx.thread.id.in.group(i32 0)
+// CHECK:  %[[#ID_X:]] = call i32 @llvm.[[TARGET]].thread.id.in.group(i32 0)
 // CHECK:  %[[#ID_X_:]] = insertelement <3 x i32> poison, i32 %[[#ID_X]], i64 0
-// CHECK:  %[[#ID_Y:]] = call i32 @llvm.dx.thread.id.in.group(i32 1)
+// CHECK:  %[[#ID_Y:]] = call i32 @llvm.[[TARGET]].thread.id.in.group(i32 1)
 // CHECK:  %[[#ID_XY:]] = insertelement <3 x i32> %[[#ID_X_]], i32 %[[#ID_Y]], i64 1
-// CHECK:  %[[#ID_Z:]] = call i32 @llvm.dx.thread.id.in.group(i32 2)
+// CHECK:  %[[#ID_Z:]] = call i32 @llvm.[[TARGET]].thread.id.in.group(i32 2)
 // CHECK:  %[[#ID_XYZ:]] = insertelement <3 x i32> %[[#ID_XY]], i32 %[[#ID_Z]], i64 2
-// CHECK:  call void @{{.*}}test{{.*}}(<3 x i32> %[[#ID_XYZ]])
+// CHECK-DXIL:   call void @{{.*}}test{{.*}}(<3 x i32> %[[#ID_XYZ]])
+// CHECK-SPIRV:  call spir_func void @{{.*}}test{{.*}}(<3 x i32> %[[#ID_XYZ]])
 [shader("compute")]
 [numthreads(8,8,1)]
 void test(uint3 Idx : SV_GroupThreadID) {}
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index 1ae3129774e507..fd0c3b2a59e1db 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -59,6 +59,7 @@ let TargetPrefix = "spv" in {
 
   // The following intrinsic(s) are mirrored from IntrinsicsDirectX.td for HLSL support.
   def int_spv_thread_id : Intrinsic<[llvm_i32_ty], [llvm_i32_ty], [IntrNoMem, IntrWillReturn]>;
+  def int_spv_thread_id_in_group : Intrinsic<[llvm_i32_ty], [llvm_i32_ty], [IntrNoMem, IntrWillReturn]>;
   def int_spv_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty], [IntrNoMem]>;
   def int_spv_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty], [IntrNoMem]>;
   def int_spv_cross : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 3a98b74b3d6757..9c831028523fbf 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -265,6 +265,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
   bool selectSpvThreadId(Register ResVReg, const SPIRVType *ResType,
                          MachineInstr &I) const;
 
+  bool selectSpvGroupThreadId(Register ResVReg, const SPIRVType *ResType,
+                              MachineInstr &I) const;
+
   bool selectWaveOpInst(Register ResVReg, const SPIRVType *ResType,
                         MachineInstr &I, unsigned Opcode) const;
 
@@ -309,6 +312,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
   SPIRVType *widenTypeToVec4(const SPIRVType *Type, MachineInstr &I) const;
   void extractSubvector(Register &ResVReg, const SPIRVType *ResType,
                         Register &ReadReg, MachineInstr &InsertionPoint) const;
+  bool loadVec3BuiltinInputID(SPIRV::BuiltIn::BuiltIn BuiltInValue,
+                              Register ResVReg, const SPIRVType *ResType,
+                              MachineInstr &I) const;
 };
 
 } // end anonymous namespace
@@ -2852,6 +2858,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
     break;
   case Intrinsic::spv_thread_id:
     return selectSpvThreadId(ResVReg, ResType, I);
+  case Intrinsic::spv_thread_id_in_group:
+    return selectSpvGroupThreadId(ResVReg, ResType, I);
   case Intrinsic::spv_fdot:
     return selectFloatDot(ResVReg, ResType, I);
   case Intrinsic::spv_udot:
@@ -3551,13 +3559,12 @@ bool SPIRVInstructionSelector::selectLog10(Register ResVReg,
                        .constrainAllUses(TII, TRI, RBI);
 }
 
-bool SPIRVInstructionSelector::selectSpvThreadId(Register ResVReg,
-                                                 const SPIRVType *ResType,
-                                                 MachineInstr &I) const {
-  // DX intrinsic: @llvm.dx.thread.id(i32)
-  // ID  Name      Description
-  // 93  ThreadId  reads the thread ID
-
+// Generate the instructions to load a 3-element vector builtin input
+// ID/index and extract the requested component from it.
+// Used for semantics such as SV_DispatchThreadID and SV_GroupThreadID.
+bool SPIRVInstructionSelector::loadVec3BuiltinInputID(
+    SPIRV::BuiltIn::BuiltIn BuiltInValue, Register ResVReg,
+    const SPIRVType *ResType, MachineInstr &I) const {
   MachineIRBuilder MIRBuilder(I);
   const SPIRVType *U32Type = GR.getOrCreateSPIRVIntegerType(32, MIRBuilder);
   const SPIRVType *Vec3Ty =
@@ -3565,16 +3572,16 @@ bool SPIRVInstructionSelector::selectSpvThreadId(Register ResVReg,
   const SPIRVType *PtrType = GR.getOrCreateSPIRVPointerType(
       Vec3Ty, MIRBuilder, SPIRV::StorageClass::Input);
 
-  // Create new register for GlobalInvocationID builtin variable.
+  // Create new register for the input ID builtin variable.
   Register NewRegister =
       MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::iIDRegClass);
   MIRBuilder.getMRI()->setType(NewRegister, LLT::pointer(0, 64));
   GR.assignSPIRVTypeToVReg(PtrType, NewRegister, MIRBuilder.getMF());
 
-  // Build GlobalInvocationID global variable with the necessary decorations.
+  // Build global variable with the necessary decorations for the input ID
+  // builtin variable.
   Register Variable = GR.buildGlobalVariable(
-      NewRegister, PtrType,
-      getLinkStringForBuiltIn(SPIRV::BuiltIn::GlobalInvocationId), nullptr,
+      NewRegister, PtrType, getLinkStringForBuiltIn(BuiltInValue), nullptr,
       SPIRV::StorageClass::Input, nullptr, true, true,
       SPIRV::LinkageType::Import, MIRBuilder, false);
 
@@ -3591,12 +3598,12 @@ bool SPIRVInstructionSelector::selectSpvThreadId(Register ResVReg,
           .addUse(GR.getSPIRVTypeID(Vec3Ty))
           .addUse(Variable);
 
-  // Get Thread ID index. Expecting operand is a constant immediate value,
+  // Get the input ID index. Expecting operand is a constant immediate value,
   // wrapped in a type assignment.
   assert(I.getOperand(2).isReg());
   const uint32_t ThreadId = foldImm(I.getOperand(2), MRI);
 
-  // Extract the thread ID from the loaded vector value.
+  // Extract the input ID from the loaded vector value.
   MachineBasicBlock &BB = *I.getParent();
   auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract))
                  .addDef(ResVReg)
@@ -3606,6 +3613,32 @@ bool SPIRVInstructionSelector::selectSpvThreadId(Register ResVReg,
   return Result && MIB.constrainAllUses(TII, TRI, RBI);
 }
 
+bool SPIRVInstructionSelector::selectSpvThreadId(Register ResVReg,
+                                                 const SPIRVType *ResType,
+                                                 MachineInstr &I) const {
+  // DX intrinsic: @llvm.dx.thread.id(i32)
+  // ID  Name      Description
+  // 93  ThreadId  reads the thread ID
+  //
+  // In SPIR-V, llvm.dx.thread.id maps to a `GlobalInvocationId` builtin
+  // variable
+  return loadVec3BuiltinInputID(SPIRV::BuiltIn::GlobalInvocationId, ResVReg,
+                                ResType, I);
+}
+
+bool SPIRVInstructionSelector::selectSpvGroupThreadId(Register ResVReg,
+                                                      const SPIRVType *ResType,
+                                                      MachineInstr &I) const {
+  // DX intrinsic: @llvm.dx.thread.id.in.group(i32)
+  // ID  Name           Description
+  // 95  GroupThreadId  Reads the thread ID within the group
+  //
+  // In SPIR-V, llvm.dx.thread.id.in.group maps to a `LocalInvocationId` builtin
+  // variable
+  return loadVec3BuiltinInputID(SPIRV::BuiltIn::LocalInvocationId, ResVReg,
+                                ResType, I);
+}
+
 SPIRVType *SPIRVInstructionSelector::widenTypeToVec4(const SPIRVType *Type,
                                                      MachineInstr &I) const {
   MachineIRBuilder MIRBuilder(I);
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/SV_GroupThreadID.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/SV_GroupThreadID.ll
new file mode 100644
index 00000000000000..a88debf97fa7bb
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/SV_GroupThreadID.ll
@@ -0,0 +1,76 @@
+; RUN: llc -O0 -verify-machineinstrs -mtriple=spirv-vulkan-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-vulkan-unknown %s -o - -filetype=obj | spirv-val %}
+
+; This file was generated from the following command:
+; clang -cc1 -triple spirv-vulkan-library -x hlsl -emit-llvm -disable-llvm-passes -finclude-default-header - -o - <<EOF
+; [shader("compute")]
+; [numthreads(1,1,1)]
+; void main(uint3 ID : SV_GroupThreadID) {}
+; EOF
+
+; CHECK-DAG:        %[[#int:]] = OpTypeInt 32 0
+; CHECK-DAG:        %[[#v3int:]] = OpTypeVector %[[#int]] 3
+; CHECK-DAG:        %[[#ptr_Input_v3int:]] = OpTypePointer Input %[[#v3int]]
+; CHECK-DAG:        %[[#tempvar:]] = OpUndef %[[#v3int]]
+; CHECK-DAG:        %[[#LocalInvocationId:]] = OpVariable %[[#ptr_Input_v3int]] Input
+
+; CHECK-DAG:        OpEntryPoint GLCompute {{.*}} %[[#LocalInvocationId]]
+; CHECK-DAG:        OpName %[[#LocalInvocationId]] "__spirv_BuiltInLocalInvocationId"
+; CHECK-DAG:        OpDecorate %[[#LocalInvocationId]] LinkageAttributes "__spirv_BuiltInLocalInvocationId" Import
+; CHECK-DAG:        OpDecorate %[[#LocalInvocationId]] BuiltIn LocalInvocationId
+
+; ModuleID = '-'
+source_filename = "-"
+target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
+target triple = "spirv-unknown-vulkan-library"
+
+; Function Attrs: noinline norecurse nounwind optnone
+define internal spir_func void @main(<3 x i32> noundef %ID) #0 {
+entry:
+  %ID.addr = alloca <3 x i32>, align 16
+  store <3 x i32> %ID, ptr %ID.addr, align 16
+  ret void
+}
+
+; Function Attrs: norecurse
+define void @main.1() #1 {
+entry:
+
+; CHECK:        %[[#load:]] = OpLoad %[[#v3int]] %[[#LocalInvocationId]]
+; CHECK:        %[[#load0:]] = OpCompositeExtract %[[#int]] %[[#load]] 0
+  %0 = call i32 @llvm.spv.thread.id.in.group(i32 0)
+
+; CHECK:        %[[#tempvar:]] = OpCompositeInsert %[[#v3int]] %[[#load0]] %[[#tempvar]] 0
+  %1 = insertelement <3 x i32> poison, i32 %0, i64 0
+
+; CHECK:        %[[#load:]] = OpLoad %[[#v3int]] %[[#LocalInvocationId]]
+; CHECK:        %[[#load1:]] = OpCompositeExtract %[[#int]] %[[#load]] 1
+  %2 = call i32 @llvm.spv.thread.id.in.group(i32 1)
+
+; CHECK:        %[[#tempvar:]] = OpCompositeInsert %[[#v3int]] %[[#load1]] %[[#tempvar]] 1
+  %3 = insertelement <3 x i32> %1, i32 %2, i64 1
+
+; CHECK:        %[[#load:]] = OpLoad %[[#v3int]] %[[#LocalInvocationId]]
+; CHECK:        %[[#load2:]] = OpCompositeExtract %[[#int]] %[[#load]] 2
+  %4 = call i32 @llvm.spv.thread.id.in.group(i32 2)
+
+; CHECK:        %[[#tempvar:]] = OpCompositeInsert %[[#v3int]] %[[#load2]] %[[#tempvar]] 2
+  %5 = insertelement <3 x i32> %3, i32 %4, i64 2
+
+  call void @main(<3 x i32> %5)
+  ret void
+}
+
+; Function Attrs: nounwind willreturn memory(none)
+declare i32 @llvm.spv.thread.id.in.group(i32) #2
+
+attributes #0 = { noinline norecurse nounwind optnone "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+attributes #1 = { norecurse "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+attributes #2 = { nounwind willreturn memory(none) }
+
+!llvm.module.flags = !{!0, !1}
+!llvm.ident = !{!2}
+
+!0 = !{i32 1, !"wchar_size", i32 4}
+!1 = !{i32 4, !"dx.disable_optimizations", i32 1}
+!2 = !{!"clang version 19.0.0git (git at github.com:llvm/llvm-project.git 91600507765679e92434ec7c5edb883bf01f847f)"}



More information about the cfe-commits mailing list