[llvm] [SPIR-V] Implement insertion of 'Group and Subgroup Instructions' using builtin functions (PR #95176)

Vyacheslav Levytskyy via llvm-commits llvm-commits at lists.llvm.org
Tue Jun 11 14:33:15 PDT 2024


https://github.com/VyacheslavLevytskyy created https://github.com/llvm/llvm-project/pull/95176

This PR adds builtin functions to insert instructions from the 'Group and Subgroup Instructions' section of the SPIR-V Specification. The corresponding tests are updated, and a `spirv-val` run is added where it was missing.

>From 1de94d74e4c543904066b86136f38d9d10e41e5d Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Tue, 11 Jun 2024 14:30:39 -0700
Subject: [PATCH] implement insertion of 'Group and Subgroup Instructions'
 using builtin functions-wrappers

---
 llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp       |  5 +++++
 llvm/lib/Target/SPIRV/SPIRVBuiltins.td        | 16 ++++++++++++++
 .../SPIRV/transcoding/OpGroupAllAny.ll        | 21 ++++++++++++++-----
 .../CodeGen/SPIRV/transcoding/group_ops.ll    |  9 +++++++-
 4 files changed, 45 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 6bb3e215240a8..bfe2e01387279 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -1015,6 +1015,11 @@ static bool generateGroupInst(const SPIRV::IncomingCall *Call,
   const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
   const SPIRV::GroupBuiltin *GroupBuiltin =
       SPIRV::lookupGroupBuiltin(Builtin->Name);
+
+  if (Call->isSpirvOp())
+    return buildOpFromWrapper(MIRBuilder, GroupBuiltin->Opcode, Call,
+                              GR->getSPIRVTypeID(Call->ReturnType));
+
   MachineRegisterInfo *MRI = MIRBuilder.getMRI();
   Register Arg0;
   if (GroupBuiltin->HasBoolArg) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
index 2edd2992425bd..d93756cc67c9c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
@@ -694,9 +694,17 @@ multiclass DemangledGroupBuiltin<string name, int level /* OnlyWork/OnlySub/...
   }
 }
 
+multiclass DemangledGroupBuiltinWrapper<string name, bits<8> minNumArgs, bits<8> maxNumArgs, Op operation> {
+    def : DemangledBuiltin<name, OpenCL_std, Group, minNumArgs, maxNumArgs>;
+    def : GroupBuiltin<name, operation>;
+}
+
 defm : DemangledGroupBuiltin<"group_all", WorkOrSub, OpGroupAll>;
+defm : DemangledGroupBuiltinWrapper<"__spirv_GroupAll", 2, 2, OpGroupAll>;
 defm : DemangledGroupBuiltin<"group_any", WorkOrSub, OpGroupAny>;
+defm : DemangledGroupBuiltinWrapper<"__spirv_GroupAny", 2, 2, OpGroupAny>;
 defm : DemangledGroupBuiltin<"group_broadcast", WorkOrSub, OpGroupBroadcast>;
+defm : DemangledGroupBuiltinWrapper<"__spirv_GroupBroadcast", 3, 3, OpGroupBroadcast>;
 defm : DemangledGroupBuiltin<"group_non_uniform_broadcast", OnlySub, OpGroupNonUniformBroadcast>;
 defm : DemangledGroupBuiltin<"group_broadcast_first", OnlySub, OpGroupNonUniformBroadcastFirst>;
 
@@ -731,41 +739,49 @@ defm : DemangledGroupBuiltin<"group_scan_inclusive_adds", WorkOrSub, OpGroupIAdd
 defm : DemangledGroupBuiltin<"group_reduce_addu", WorkOrSub, OpGroupIAdd>;
 defm : DemangledGroupBuiltin<"group_scan_exclusive_addu", WorkOrSub, OpGroupIAdd>;
 defm : DemangledGroupBuiltin<"group_scan_inclusive_addu", WorkOrSub, OpGroupIAdd>;
+defm : DemangledGroupBuiltinWrapper<"__spirv_GroupIAdd", 3, 3, OpGroupIAdd>;
 
 defm : DemangledGroupBuiltin<"group_fadd", WorkOrSub, OpGroupFAdd>;
 defm : DemangledGroupBuiltin<"group_reduce_addf", WorkOrSub, OpGroupFAdd>;
 defm : DemangledGroupBuiltin<"group_scan_exclusive_addf", WorkOrSub, OpGroupFAdd>;
 defm : DemangledGroupBuiltin<"group_scan_inclusive_addf", WorkOrSub, OpGroupFAdd>;
+defm : DemangledGroupBuiltinWrapper<"__spirv_GroupFAdd", 3, 3, OpGroupFAdd>;
 
 defm : DemangledGroupBuiltin<"group_fmin", WorkOrSub, OpGroupFMin>;
 defm : DemangledGroupBuiltin<"group_reduce_minf", WorkOrSub, OpGroupFMin>;
 defm : DemangledGroupBuiltin<"group_scan_exclusive_minf", WorkOrSub, OpGroupFMin>;
 defm : DemangledGroupBuiltin<"group_scan_inclusive_minf", WorkOrSub, OpGroupFMin>;
+defm : DemangledGroupBuiltinWrapper<"__spirv_GroupFMin", 3, 3, OpGroupFMin>;
 
 defm : DemangledGroupBuiltin<"group_umin", WorkOrSub, OpGroupUMin>;
 defm : DemangledGroupBuiltin<"group_reduce_minu", WorkOrSub, OpGroupUMin>;
 defm : DemangledGroupBuiltin<"group_scan_exclusive_minu", WorkOrSub, OpGroupUMin>;
 defm : DemangledGroupBuiltin<"group_scan_inclusive_minu", WorkOrSub, OpGroupUMin>;
+defm : DemangledGroupBuiltinWrapper<"__spirv_GroupUMin", 3, 3, OpGroupUMin>;
 
 defm : DemangledGroupBuiltin<"group_smin", WorkOrSub, OpGroupSMin>;
 defm : DemangledGroupBuiltin<"group_reduce_mins", WorkOrSub, OpGroupSMin>;
 defm : DemangledGroupBuiltin<"group_scan_exclusive_mins", WorkOrSub, OpGroupSMin>;
 defm : DemangledGroupBuiltin<"group_scan_inclusive_mins", WorkOrSub, OpGroupSMin>;
+defm : DemangledGroupBuiltinWrapper<"__spirv_GroupSMin", 3, 3, OpGroupSMin>;
 
 defm : DemangledGroupBuiltin<"group_fmax", WorkOrSub, OpGroupFMax>;
 defm : DemangledGroupBuiltin<"group_reduce_maxf", WorkOrSub, OpGroupFMax>;
 defm : DemangledGroupBuiltin<"group_scan_exclusive_maxf", WorkOrSub, OpGroupFMax>;
 defm : DemangledGroupBuiltin<"group_scan_inclusive_maxf", WorkOrSub, OpGroupFMax>;
+defm : DemangledGroupBuiltinWrapper<"__spirv_GroupFMax", 3, 3, OpGroupFMax>;
 
 defm : DemangledGroupBuiltin<"group_umax", WorkOrSub, OpGroupUMax>;
 defm : DemangledGroupBuiltin<"group_reduce_maxu", WorkOrSub, OpGroupUMax>;
 defm : DemangledGroupBuiltin<"group_scan_exclusive_maxu", WorkOrSub, OpGroupUMax>;
 defm : DemangledGroupBuiltin<"group_scan_inclusive_maxu", WorkOrSub, OpGroupUMax>;
+defm : DemangledGroupBuiltinWrapper<"__spirv_GroupUMax", 3, 3, OpGroupUMax>;
 
 defm : DemangledGroupBuiltin<"group_smax", WorkOrSub, OpGroupSMax>;
 defm : DemangledGroupBuiltin<"group_reduce_maxs", WorkOrSub, OpGroupSMax>;
 defm : DemangledGroupBuiltin<"group_scan_exclusive_maxs", WorkOrSub, OpGroupSMax>;
 defm : DemangledGroupBuiltin<"group_scan_inclusive_maxs", WorkOrSub, OpGroupSMax>;
+defm : DemangledGroupBuiltinWrapper<"__spirv_GroupSMax", 3, 3, OpGroupSMax>;
 
 // cl_khr_subgroup_non_uniform_arithmetic
 defm : DemangledGroupBuiltin<"group_non_uniform_iadd", WorkOrSub, OpGroupNonUniformIAdd>;
diff --git a/llvm/test/CodeGen/SPIRV/transcoding/OpGroupAllAny.ll b/llvm/test/CodeGen/SPIRV/transcoding/OpGroupAllAny.ll
index 5f11ca2a474e9..39f9e0581d2f8 100644
--- a/llvm/test/CodeGen/SPIRV/transcoding/OpGroupAllAny.ll
+++ b/llvm/test/CodeGen/SPIRV/transcoding/OpGroupAllAny.ll
@@ -1,18 +1,29 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
 ; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
 
 ; CHECK-SPIRV: OpCapability Groups
-; CHECK-SPIRV: %[[#BoolTypeID:]] = OpTypeBool
-; CHECK-SPIRV: %[[#ConstID:]] = OpConstantTrue %[[#BoolTypeID]]
-; CHECK-SPIRV: %[[#]] = OpGroupAll %[[#BoolTypeID]] %[[#]] %[[#ConstID]]
-; CHECK-SPIRV: %[[#]] = OpGroupAny %[[#BoolTypeID]] %[[#]] %[[#ConstID]]
+; CHECK-SPIRV-DAG: %[[#BoolTypeID:]] = OpTypeBool
+; CHECK-SPIRV-DAG: %[[#True:]] = OpConstantTrue %[[#BoolTypeID]]
+; CHECK-SPIRV-DAG: %[[#False:]] = OpConstantFalse %[[#BoolTypeID]]
+; CHECK-SPIRV: %[[#]] = OpGroupAll %[[#BoolTypeID]] %[[#]] %[[#True]]
+; CHECK-SPIRV: %[[#]] = OpGroupAny %[[#BoolTypeID]] %[[#]] %[[#True]]
+; CHECK-SPIRV: %[[#]] = OpGroupAll %[[#BoolTypeID]] %[[#]] %[[#True]]
+; CHECK-SPIRV: %[[#]] = OpGroupAny %[[#BoolTypeID]] %[[#]] %[[#False]]
 
 define spir_kernel void @test(i32 addrspace(1)* nocapture readnone %i) {
 entry:
   %call = tail call spir_func i32 @_Z14work_group_alli(i32 5)
   %call1 = tail call spir_func i32 @_Z14work_group_anyi(i32 5)
+  %call3 = tail call spir_func i32 @__spirv_GroupAll(i32 0, i1 1)
+  %call4 = tail call spir_func i32 @__spirv_GroupAny(i32 0, i1 0)
   ret void
 }
 
 declare spir_func i32 @_Z14work_group_alli(i32)
-
 declare spir_func i32 @_Z14work_group_anyi(i32)
+
+declare spir_func i1 @__spirv_GroupAll(i32, i1)
+declare spir_func i1 @__spirv_GroupAny(i32, i1)
diff --git a/llvm/test/CodeGen/SPIRV/transcoding/group_ops.ll b/llvm/test/CodeGen/SPIRV/transcoding/group_ops.ll
index 2412f406a9c62..a0a8c64faedd7 100644
--- a/llvm/test/CodeGen/SPIRV/transcoding/group_ops.ll
+++ b/llvm/test/CodeGen/SPIRV/transcoding/group_ops.ll
@@ -1,8 +1,12 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
 ; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
 ; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
 
 ; CHECK-SPIRV-DAG: %[[#int:]] = OpTypeInt 32 0
 ; CHECK-SPIRV-DAG: %[[#float:]] = OpTypeFloat 32
+; CHECK-SPIRV-DAG: %[[#ScopeCrossWorkgroup:]] = OpConstant %[[#int]] 0
 ; CHECK-SPIRV-DAG: %[[#ScopeWorkgroup:]] = OpConstant %[[#int]] 2
 ; CHECK-SPIRV-DAG: %[[#ScopeSubgroup:]] = OpConstant %[[#int]] 3
 
@@ -247,7 +251,8 @@ entry:
 declare spir_func i32 @_Z21work_group_reduce_minj(i32 noundef) local_unnamed_addr
 
 ; CHECK-SPIRV: OpFunction
-; CHECK-SPIRV: %[[#]] = OpGroupBroadcast %[[#int]] %[[#ScopeWorkgroup]]
+; CHECK-SPIRV: %[[#]] = OpGroupBroadcast %[[#int]] %[[#ScopeWorkgroup]] %[[#BroadcastValue:]] %[[#BroadcastLocalId:]]
+; CHECK-SPIRV: %[[#]] = OpGroupBroadcast %[[#int]] %[[#ScopeCrossWorkgroup]] %[[#BroadcastValue]] %[[#BroadcastLocalId]]
 ; CHECK-SPIRV: OpFunctionEnd
 
 ;; kernel void testWorkGroupBroadcast(uint a, global size_t *id, global int *res) {
@@ -259,7 +264,9 @@ entry:
   %0 = load i32, i32 addrspace(1)* %id, align 4
   %call = call spir_func i32 @_Z20work_group_broadcastjj(i32 noundef %a, i32 noundef %0)
   store i32 %call, i32 addrspace(1)* %res, align 4
+  %call1 = call spir_func i32 @__spirv_GroupBroadcast(i32 0, i32 noundef %a, i32 noundef %0)
   ret void
 }
 
 declare spir_func i32 @_Z20work_group_broadcastjj(i32 noundef, i32 noundef) local_unnamed_addr
+declare spir_func i32 @__spirv_GroupBroadcast(i32 noundef, i32 noundef, i32 noundef) local_unnamed_addr



More information about the llvm-commits mailing list