[llvm] [SPIR-V] Implement insertion of 'Group and Subgroup Instructions' using builtin functions (PR #95176)

Vyacheslav Levytskyy via llvm-commits llvm-commits at lists.llvm.org
Wed Jun 12 03:35:43 PDT 2024


https://github.com/VyacheslavLevytskyy updated https://github.com/llvm/llvm-project/pull/95176

>From 1de94d74e4c543904066b86136f38d9d10e41e5d Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Tue, 11 Jun 2024 14:30:39 -0700
Subject: [PATCH 1/2] implement insertion of 'Group and Subgroup Instructions'
 using builtin wrapper functions

---
 llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp       |  5 +++++
 llvm/lib/Target/SPIRV/SPIRVBuiltins.td        | 16 ++++++++++++++
 .../SPIRV/transcoding/OpGroupAllAny.ll        | 21 ++++++++++++++-----
 .../CodeGen/SPIRV/transcoding/group_ops.ll    |  9 +++++++-
 4 files changed, 45 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 6bb3e215240a8..bfe2e01387279 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -1015,6 +1015,11 @@ static bool generateGroupInst(const SPIRV::IncomingCall *Call,
   const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
   const SPIRV::GroupBuiltin *GroupBuiltin =
       SPIRV::lookupGroupBuiltin(Builtin->Name);
+
+  if (Call->isSpirvOp())
+    return buildOpFromWrapper(MIRBuilder, GroupBuiltin->Opcode, Call,
+                              GR->getSPIRVTypeID(Call->ReturnType));
+
   MachineRegisterInfo *MRI = MIRBuilder.getMRI();
   Register Arg0;
   if (GroupBuiltin->HasBoolArg) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
index 2edd2992425bd..d93756cc67c9c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
@@ -694,9 +694,17 @@ multiclass DemangledGroupBuiltin<string name, int level /* OnlyWork/OnlySub/...
   }
 }
 
+multiclass DemangledGroupBuiltinWrapper<string builtinName, bits<8> minArgs, bits<8> maxArgs, Op groupOpcode> {
+  def : DemangledBuiltin<builtinName, OpenCL_std, Group, minArgs, maxArgs>;
+  def : GroupBuiltin<builtinName, groupOpcode>;
+} // Instantiated below for each __spirv_Group* wrapper builtin.
+
 defm : DemangledGroupBuiltin<"group_all", WorkOrSub, OpGroupAll>;
+defm : DemangledGroupBuiltinWrapper<"__spirv_GroupAll", 2, 2, OpGroupAll>;
 defm : DemangledGroupBuiltin<"group_any", WorkOrSub, OpGroupAny>;
+defm : DemangledGroupBuiltinWrapper<"__spirv_GroupAny", 2, 2, OpGroupAny>;
 defm : DemangledGroupBuiltin<"group_broadcast", WorkOrSub, OpGroupBroadcast>;
+defm : DemangledGroupBuiltinWrapper<"__spirv_GroupBroadcast", 3, 3, OpGroupBroadcast>;
 defm : DemangledGroupBuiltin<"group_non_uniform_broadcast", OnlySub, OpGroupNonUniformBroadcast>;
 defm : DemangledGroupBuiltin<"group_broadcast_first", OnlySub, OpGroupNonUniformBroadcastFirst>;
 
@@ -731,41 +739,49 @@ defm : DemangledGroupBuiltin<"group_scan_inclusive_adds", WorkOrSub, OpGroupIAdd
 defm : DemangledGroupBuiltin<"group_reduce_addu", WorkOrSub, OpGroupIAdd>;
 defm : DemangledGroupBuiltin<"group_scan_exclusive_addu", WorkOrSub, OpGroupIAdd>;
 defm : DemangledGroupBuiltin<"group_scan_inclusive_addu", WorkOrSub, OpGroupIAdd>;
+defm : DemangledGroupBuiltinWrapper<"__spirv_GroupIAdd", 3, 3, OpGroupIAdd>;
 
 defm : DemangledGroupBuiltin<"group_fadd", WorkOrSub, OpGroupFAdd>;
 defm : DemangledGroupBuiltin<"group_reduce_addf", WorkOrSub, OpGroupFAdd>;
 defm : DemangledGroupBuiltin<"group_scan_exclusive_addf", WorkOrSub, OpGroupFAdd>;
 defm : DemangledGroupBuiltin<"group_scan_inclusive_addf", WorkOrSub, OpGroupFAdd>;
+defm : DemangledGroupBuiltinWrapper<"__spirv_GroupFAdd", 3, 3, OpGroupFAdd>;
 
 defm : DemangledGroupBuiltin<"group_fmin", WorkOrSub, OpGroupFMin>;
 defm : DemangledGroupBuiltin<"group_reduce_minf", WorkOrSub, OpGroupFMin>;
 defm : DemangledGroupBuiltin<"group_scan_exclusive_minf", WorkOrSub, OpGroupFMin>;
 defm : DemangledGroupBuiltin<"group_scan_inclusive_minf", WorkOrSub, OpGroupFMin>;
+defm : DemangledGroupBuiltinWrapper<"__spirv_GroupFMin", 3, 3, OpGroupFMin>;
 
 defm : DemangledGroupBuiltin<"group_umin", WorkOrSub, OpGroupUMin>;
 defm : DemangledGroupBuiltin<"group_reduce_minu", WorkOrSub, OpGroupUMin>;
 defm : DemangledGroupBuiltin<"group_scan_exclusive_minu", WorkOrSub, OpGroupUMin>;
 defm : DemangledGroupBuiltin<"group_scan_inclusive_minu", WorkOrSub, OpGroupUMin>;
+defm : DemangledGroupBuiltinWrapper<"__spirv_GroupUMin", 3, 3, OpGroupUMin>;
 
 defm : DemangledGroupBuiltin<"group_smin", WorkOrSub, OpGroupSMin>;
 defm : DemangledGroupBuiltin<"group_reduce_mins", WorkOrSub, OpGroupSMin>;
 defm : DemangledGroupBuiltin<"group_scan_exclusive_mins", WorkOrSub, OpGroupSMin>;
 defm : DemangledGroupBuiltin<"group_scan_inclusive_mins", WorkOrSub, OpGroupSMin>;
+defm : DemangledGroupBuiltinWrapper<"__spirv_GroupSMin", 3, 3, OpGroupSMin>;
 
 defm : DemangledGroupBuiltin<"group_fmax", WorkOrSub, OpGroupFMax>;
 defm : DemangledGroupBuiltin<"group_reduce_maxf", WorkOrSub, OpGroupFMax>;
 defm : DemangledGroupBuiltin<"group_scan_exclusive_maxf", WorkOrSub, OpGroupFMax>;
 defm : DemangledGroupBuiltin<"group_scan_inclusive_maxf", WorkOrSub, OpGroupFMax>;
+defm : DemangledGroupBuiltinWrapper<"__spirv_GroupFMax", 3, 3, OpGroupFMax>;
 
 defm : DemangledGroupBuiltin<"group_umax", WorkOrSub, OpGroupUMax>;
 defm : DemangledGroupBuiltin<"group_reduce_maxu", WorkOrSub, OpGroupUMax>;
 defm : DemangledGroupBuiltin<"group_scan_exclusive_maxu", WorkOrSub, OpGroupUMax>;
 defm : DemangledGroupBuiltin<"group_scan_inclusive_maxu", WorkOrSub, OpGroupUMax>;
+defm : DemangledGroupBuiltinWrapper<"__spirv_GroupUMax", 3, 3, OpGroupUMax>;
 
 defm : DemangledGroupBuiltin<"group_smax", WorkOrSub, OpGroupSMax>;
 defm : DemangledGroupBuiltin<"group_reduce_maxs", WorkOrSub, OpGroupSMax>;
 defm : DemangledGroupBuiltin<"group_scan_exclusive_maxs", WorkOrSub, OpGroupSMax>;
 defm : DemangledGroupBuiltin<"group_scan_inclusive_maxs", WorkOrSub, OpGroupSMax>;
+defm : DemangledGroupBuiltinWrapper<"__spirv_GroupSMax", 3, 3, OpGroupSMax>;
 
 // cl_khr_subgroup_non_uniform_arithmetic
 defm : DemangledGroupBuiltin<"group_non_uniform_iadd", WorkOrSub, OpGroupNonUniformIAdd>;
diff --git a/llvm/test/CodeGen/SPIRV/transcoding/OpGroupAllAny.ll b/llvm/test/CodeGen/SPIRV/transcoding/OpGroupAllAny.ll
index 5f11ca2a474e9..39f9e0581d2f8 100644
--- a/llvm/test/CodeGen/SPIRV/transcoding/OpGroupAllAny.ll
+++ b/llvm/test/CodeGen/SPIRV/transcoding/OpGroupAllAny.ll
@@ -1,18 +1,29 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
 ; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
 
 ; CHECK-SPIRV: OpCapability Groups
-; CHECK-SPIRV: %[[#BoolTypeID:]] = OpTypeBool
-; CHECK-SPIRV: %[[#ConstID:]] = OpConstantTrue %[[#BoolTypeID]]
-; CHECK-SPIRV: %[[#]] = OpGroupAll %[[#BoolTypeID]] %[[#]] %[[#ConstID]]
-; CHECK-SPIRV: %[[#]] = OpGroupAny %[[#BoolTypeID]] %[[#]] %[[#ConstID]]
+; CHECK-SPIRV-DAG: %[[#BoolTypeID:]] = OpTypeBool
+; CHECK-SPIRV-DAG: %[[#True:]] = OpConstantTrue %[[#BoolTypeID]]
+; CHECK-SPIRV-DAG: %[[#False:]] = OpConstantFalse %[[#BoolTypeID]]
+; CHECK-SPIRV: %[[#]] = OpGroupAll %[[#BoolTypeID]] %[[#]] %[[#True]]
+; CHECK-SPIRV: %[[#]] = OpGroupAny %[[#BoolTypeID]] %[[#]] %[[#True]]
+; CHECK-SPIRV: %[[#]] = OpGroupAll %[[#BoolTypeID]] %[[#]] %[[#True]]
+; CHECK-SPIRV: %[[#]] = OpGroupAny %[[#BoolTypeID]] %[[#]] %[[#False]]
 
 define spir_kernel void @test(i32 addrspace(1)* nocapture readnone %i) {
 entry:
   %call = tail call spir_func i32 @_Z14work_group_alli(i32 5)
   %call1 = tail call spir_func i32 @_Z14work_group_anyi(i32 5)
+  %call3 = tail call spir_func i1 @__spirv_GroupAll(i32 0, i1 1) ; i1 result matches the declaration
+  %call4 = tail call spir_func i1 @__spirv_GroupAny(i32 0, i1 0)
   ret void
 }
 
 declare spir_func i32 @_Z14work_group_alli(i32)
-
 declare spir_func i32 @_Z14work_group_anyi(i32)
+
+declare spir_func i1 @__spirv_GroupAll(i32, i1)
+declare spir_func i1 @__spirv_GroupAny(i32, i1)
diff --git a/llvm/test/CodeGen/SPIRV/transcoding/group_ops.ll b/llvm/test/CodeGen/SPIRV/transcoding/group_ops.ll
index 2412f406a9c62..a0a8c64faedd7 100644
--- a/llvm/test/CodeGen/SPIRV/transcoding/group_ops.ll
+++ b/llvm/test/CodeGen/SPIRV/transcoding/group_ops.ll
@@ -1,8 +1,12 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
 ; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
 ; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
 
 ; CHECK-SPIRV-DAG: %[[#int:]] = OpTypeInt 32 0
 ; CHECK-SPIRV-DAG: %[[#float:]] = OpTypeFloat 32
+; CHECK-SPIRV-DAG: %[[#ScopeCrossWorkgroup:]] = OpConstant %[[#int]] 0
 ; CHECK-SPIRV-DAG: %[[#ScopeWorkgroup:]] = OpConstant %[[#int]] 2
 ; CHECK-SPIRV-DAG: %[[#ScopeSubgroup:]] = OpConstant %[[#int]] 3
 
@@ -247,7 +251,8 @@ entry:
 declare spir_func i32 @_Z21work_group_reduce_minj(i32 noundef) local_unnamed_addr
 
 ; CHECK-SPIRV: OpFunction
-; CHECK-SPIRV: %[[#]] = OpGroupBroadcast %[[#int]] %[[#ScopeWorkgroup]]
+; CHECK-SPIRV: %[[#]] = OpGroupBroadcast %[[#int]] %[[#ScopeWorkgroup]] %[[#BroadcastValue:]] %[[#BroadcastLocalId:]]
+; CHECK-SPIRV: %[[#]] = OpGroupBroadcast %[[#int]] %[[#ScopeCrossWorkgroup]] %[[#BroadcastValue]] %[[#BroadcastLocalId]]
 ; CHECK-SPIRV: OpFunctionEnd
 
 ;; kernel void testWorkGroupBroadcast(uint a, global size_t *id, global int *res) {
@@ -259,7 +264,9 @@ entry:
   %0 = load i32, i32 addrspace(1)* %id, align 4
   %call = call spir_func i32 @_Z20work_group_broadcastjj(i32 noundef %a, i32 noundef %0)
   store i32 %call, i32 addrspace(1)* %res, align 4
+  %call1 = call spir_func i32 @__spirv_GroupBroadcast(i32 0, i32 noundef %a, i32 noundef %0)
   ret void
 }
 
 declare spir_func i32 @_Z20work_group_broadcastjj(i32 noundef, i32 noundef) local_unnamed_addr
+declare spir_func i32 @__spirv_GroupBroadcast(i32 noundef, i32 noundef, i32 noundef) local_unnamed_addr

>From 5fc6aa88843fffd4a318beffab4db9e2493093a8 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Wed, 12 Jun 2024 03:35:32 -0700
Subject: [PATCH 2/2] fix and add test cases for instructions with a Group
 Operation

---
 llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp       | 33 ++++++++++++++++---
 .../CodeGen/SPIRV/transcoding/group_ops.ll    | 33 +++++++++++++++++++
 2 files changed, 62 insertions(+), 4 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index bfe2e01387279..cfed7e584e7bc 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -1016,11 +1016,36 @@ static bool generateGroupInst(const SPIRV::IncomingCall *Call,
   const SPIRV::GroupBuiltin *GroupBuiltin =
       SPIRV::lookupGroupBuiltin(Builtin->Name);
 
-  if (Call->isSpirvOp())
-    return buildOpFromWrapper(MIRBuilder, GroupBuiltin->Opcode, Call,
-                              GR->getSPIRVTypeID(Call->ReturnType));
-
   MachineRegisterInfo *MRI = MIRBuilder.getMRI();
+  if (Call->isSpirvOp()) {
+    if (GroupBuiltin->NoGroupOperation)
+      return buildOpFromWrapper(MIRBuilder, GroupBuiltin->Opcode, Call,
+                                GR->getSPIRVTypeID(Call->ReturnType));
+
+    // Group Operation is a literal
+    Register GroupOpReg = Call->Arguments[1];
+    const MachineInstr *MI = getDefInstrMaybeConstant(GroupOpReg, MRI);
+    if (!MI || MI->getOpcode() != TargetOpcode::G_CONSTANT) {
+      std::string DiagMsg = std::string(Builtin->Name) +
+                            ": expect a constant value of Group Operation";
+      report_fatal_error(DiagMsg.c_str());
+    }
+    uint64_t GrpOp = MI->getOperand(1).getCImm()->getValue().getZExtValue();
+    Register ScopeReg = Call->Arguments[0];
+    if (!MRI->getRegClassOrNull(ScopeReg))
+      MRI->setRegClass(ScopeReg, &SPIRV::IDRegClass);
+    Register ValueReg = Call->Arguments[2];
+    if (!MRI->getRegClassOrNull(ValueReg))
+      MRI->setRegClass(ValueReg, &SPIRV::IDRegClass);
+    MIRBuilder.buildInstr(GroupBuiltin->Opcode)
+        .addDef(Call->ReturnRegister)
+        .addUse(GR->getSPIRVTypeID(Call->ReturnType))
+        .addUse(ScopeReg)
+        .addImm(GrpOp)
+        .addUse(ValueReg);
+    return true;
+  }
+
   Register Arg0;
   if (GroupBuiltin->HasBoolArg) {
     Register ConstRegister = Call->Arguments[0];
diff --git a/llvm/test/CodeGen/SPIRV/transcoding/group_ops.ll b/llvm/test/CodeGen/SPIRV/transcoding/group_ops.ll
index a0a8c64faedd7..a40a61a8dc7d7 100644
--- a/llvm/test/CodeGen/SPIRV/transcoding/group_ops.ll
+++ b/llvm/test/CodeGen/SPIRV/transcoding/group_ops.ll
@@ -270,3 +270,36 @@ entry:
 
 declare spir_func i32 @_Z20work_group_broadcastjj(i32 noundef, i32 noundef) local_unnamed_addr
 declare spir_func i32 @__spirv_GroupBroadcast(i32 noundef, i32 noundef, i32 noundef) local_unnamed_addr
+
+; CHECK-SPIRV: OpFunction
+; CHECK-SPIRV: %[[#]] = OpGroupFAdd %[[#float]] %[[#ScopeCrossWorkgroup]] Reduce %[[#FValue:]]
+; CHECK-SPIRV: %[[#]] = OpGroupFMin %[[#float]] %[[#ScopeWorkgroup]] InclusiveScan %[[#FValue]]
+; CHECK-SPIRV: %[[#]] = OpGroupFMax %[[#float]] %[[#ScopeSubgroup]] ExclusiveScan %[[#FValue]]
+; CHECK-SPIRV: %[[#]] = OpGroupIAdd %[[#int]] %[[#ScopeCrossWorkgroup]] Reduce %[[#IValue:]]
+; CHECK-SPIRV: %[[#]] = OpGroupUMin %[[#int]] %[[#ScopeWorkgroup]] InclusiveScan %[[#IValue]]
+; CHECK-SPIRV: %[[#]] = OpGroupSMin %[[#int]] %[[#ScopeSubgroup]] ExclusiveScan %[[#IValue]]
+; CHECK-SPIRV: %[[#]] = OpGroupUMax %[[#int]] %[[#ScopeCrossWorkgroup]] Reduce %[[#IValue]]
+; CHECK-SPIRV: %[[#]] = OpGroupSMax %[[#int]] %[[#ScopeWorkgroup]] InclusiveScan %[[#IValue]]
+; CHECK-SPIRV: OpFunctionEnd
+
+define spir_kernel void @foo(float %fval, i32 %ival) {
+entry: ; operands: Scope, GroupOperation (0=Reduce, 1=InclusiveScan, 2=ExclusiveScan), Value
+  %fadd = call spir_func float @__spirv_GroupFAdd(i32 0, i32 0, float %fval)
+  %fmin = call spir_func float @__spirv_GroupFMin(i32 2, i32 1, float %fval)
+  %fmax = call spir_func float @__spirv_GroupFMax(i32 3, i32 2, float %fval)
+  %iadd = call spir_func i32 @__spirv_GroupIAdd(i32 0, i32 0, i32 %ival)
+  %umin = call spir_func i32 @__spirv_GroupUMin(i32 2, i32 1, i32 %ival)
+  %smin = call spir_func i32 @__spirv_GroupSMin(i32 3, i32 2, i32 %ival)
+  %umax = call spir_func i32 @__spirv_GroupUMax(i32 0, i32 0, i32 %ival)
+  %smax = call spir_func i32 @__spirv_GroupSMax(i32 2, i32 1, i32 %ival)
+  ret void
+}
+
+declare spir_func float @__spirv_GroupFAdd(i32, i32, float)
+declare spir_func float @__spirv_GroupFMin(i32, i32, float)
+declare spir_func float @__spirv_GroupFMax(i32, i32, float)
+declare spir_func i32 @__spirv_GroupIAdd(i32, i32, i32)
+declare spir_func i32 @__spirv_GroupUMin(i32, i32, i32)
+declare spir_func i32 @__spirv_GroupSMin(i32, i32, i32)
+declare spir_func i32 @__spirv_GroupUMax(i32, i32, i32)
+declare spir_func i32 @__spirv_GroupSMax(i32, i32, i32)



More information about the llvm-commits mailing list