[llvm] [SPIR-V] Implement insertion of 'Group and Subgroup Instructions' using builtin functions (PR #95176)
Vyacheslav Levytskyy via llvm-commits
llvm-commits at lists.llvm.org
Wed Jun 12 03:35:43 PDT 2024
https://github.com/VyacheslavLevytskyy updated https://github.com/llvm/llvm-project/pull/95176
>From 1de94d74e4c543904066b86136f38d9d10e41e5d Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Tue, 11 Jun 2024 14:30:39 -0700
Subject: [PATCH 1/2] implement insertion of 'Group and Subgroup Instructions'
using builtin function wrappers
---
llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp | 5 +++++
llvm/lib/Target/SPIRV/SPIRVBuiltins.td | 16 ++++++++++++++
.../SPIRV/transcoding/OpGroupAllAny.ll | 21 ++++++++++++++-----
.../CodeGen/SPIRV/transcoding/group_ops.ll | 9 +++++++-
4 files changed, 45 insertions(+), 6 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 6bb3e215240a8..bfe2e01387279 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -1015,6 +1015,11 @@ static bool generateGroupInst(const SPIRV::IncomingCall *Call,
const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
const SPIRV::GroupBuiltin *GroupBuiltin =
SPIRV::lookupGroupBuiltin(Builtin->Name);
+
+ if (Call->isSpirvOp())
+ return buildOpFromWrapper(MIRBuilder, GroupBuiltin->Opcode, Call,
+ GR->getSPIRVTypeID(Call->ReturnType));
+
MachineRegisterInfo *MRI = MIRBuilder.getMRI();
Register Arg0;
if (GroupBuiltin->HasBoolArg) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
index 2edd2992425bd..d93756cc67c9c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
@@ -694,9 +694,17 @@ multiclass DemangledGroupBuiltin<string name, int level /* OnlyWork/OnlySub/...
}
}
+multiclass DemangledGroupBuiltinWrapper<string name, bits<8> minNumArgs, bits<8> maxNumArgs, Op operation> {
+ def : DemangledBuiltin<name, OpenCL_std, Group, minNumArgs, maxNumArgs>;
+ def : GroupBuiltin<name, operation>;
+}
+
defm : DemangledGroupBuiltin<"group_all", WorkOrSub, OpGroupAll>;
+defm : DemangledGroupBuiltinWrapper<"__spirv_GroupAll", 2, 2, OpGroupAll>;
defm : DemangledGroupBuiltin<"group_any", WorkOrSub, OpGroupAny>;
+defm : DemangledGroupBuiltinWrapper<"__spirv_GroupAny", 2, 2, OpGroupAny>;
defm : DemangledGroupBuiltin<"group_broadcast", WorkOrSub, OpGroupBroadcast>;
+defm : DemangledGroupBuiltinWrapper<"__spirv_GroupBroadcast", 3, 3, OpGroupBroadcast>;
defm : DemangledGroupBuiltin<"group_non_uniform_broadcast", OnlySub, OpGroupNonUniformBroadcast>;
defm : DemangledGroupBuiltin<"group_broadcast_first", OnlySub, OpGroupNonUniformBroadcastFirst>;
@@ -731,41 +739,49 @@ defm : DemangledGroupBuiltin<"group_scan_inclusive_adds", WorkOrSub, OpGroupIAdd
defm : DemangledGroupBuiltin<"group_reduce_addu", WorkOrSub, OpGroupIAdd>;
defm : DemangledGroupBuiltin<"group_scan_exclusive_addu", WorkOrSub, OpGroupIAdd>;
defm : DemangledGroupBuiltin<"group_scan_inclusive_addu", WorkOrSub, OpGroupIAdd>;
+defm : DemangledGroupBuiltinWrapper<"__spirv_GroupIAdd", 3, 3, OpGroupIAdd>;
defm : DemangledGroupBuiltin<"group_fadd", WorkOrSub, OpGroupFAdd>;
defm : DemangledGroupBuiltin<"group_reduce_addf", WorkOrSub, OpGroupFAdd>;
defm : DemangledGroupBuiltin<"group_scan_exclusive_addf", WorkOrSub, OpGroupFAdd>;
defm : DemangledGroupBuiltin<"group_scan_inclusive_addf", WorkOrSub, OpGroupFAdd>;
+defm : DemangledGroupBuiltinWrapper<"__spirv_GroupFAdd", 3, 3, OpGroupFAdd>;
defm : DemangledGroupBuiltin<"group_fmin", WorkOrSub, OpGroupFMin>;
defm : DemangledGroupBuiltin<"group_reduce_minf", WorkOrSub, OpGroupFMin>;
defm : DemangledGroupBuiltin<"group_scan_exclusive_minf", WorkOrSub, OpGroupFMin>;
defm : DemangledGroupBuiltin<"group_scan_inclusive_minf", WorkOrSub, OpGroupFMin>;
+defm : DemangledGroupBuiltinWrapper<"__spirv_GroupFMin", 3, 3, OpGroupFMin>;
defm : DemangledGroupBuiltin<"group_umin", WorkOrSub, OpGroupUMin>;
defm : DemangledGroupBuiltin<"group_reduce_minu", WorkOrSub, OpGroupUMin>;
defm : DemangledGroupBuiltin<"group_scan_exclusive_minu", WorkOrSub, OpGroupUMin>;
defm : DemangledGroupBuiltin<"group_scan_inclusive_minu", WorkOrSub, OpGroupUMin>;
+defm : DemangledGroupBuiltinWrapper<"__spirv_GroupUMin", 3, 3, OpGroupUMin>;
defm : DemangledGroupBuiltin<"group_smin", WorkOrSub, OpGroupSMin>;
defm : DemangledGroupBuiltin<"group_reduce_mins", WorkOrSub, OpGroupSMin>;
defm : DemangledGroupBuiltin<"group_scan_exclusive_mins", WorkOrSub, OpGroupSMin>;
defm : DemangledGroupBuiltin<"group_scan_inclusive_mins", WorkOrSub, OpGroupSMin>;
+defm : DemangledGroupBuiltinWrapper<"__spirv_GroupSMin", 3, 3, OpGroupSMin>;
defm : DemangledGroupBuiltin<"group_fmax", WorkOrSub, OpGroupFMax>;
defm : DemangledGroupBuiltin<"group_reduce_maxf", WorkOrSub, OpGroupFMax>;
defm : DemangledGroupBuiltin<"group_scan_exclusive_maxf", WorkOrSub, OpGroupFMax>;
defm : DemangledGroupBuiltin<"group_scan_inclusive_maxf", WorkOrSub, OpGroupFMax>;
+defm : DemangledGroupBuiltinWrapper<"__spirv_GroupFMax", 3, 3, OpGroupFMax>;
defm : DemangledGroupBuiltin<"group_umax", WorkOrSub, OpGroupUMax>;
defm : DemangledGroupBuiltin<"group_reduce_maxu", WorkOrSub, OpGroupUMax>;
defm : DemangledGroupBuiltin<"group_scan_exclusive_maxu", WorkOrSub, OpGroupUMax>;
defm : DemangledGroupBuiltin<"group_scan_inclusive_maxu", WorkOrSub, OpGroupUMax>;
+defm : DemangledGroupBuiltinWrapper<"__spirv_GroupUMax", 3, 3, OpGroupUMax>;
defm : DemangledGroupBuiltin<"group_smax", WorkOrSub, OpGroupSMax>;
defm : DemangledGroupBuiltin<"group_reduce_maxs", WorkOrSub, OpGroupSMax>;
defm : DemangledGroupBuiltin<"group_scan_exclusive_maxs", WorkOrSub, OpGroupSMax>;
defm : DemangledGroupBuiltin<"group_scan_inclusive_maxs", WorkOrSub, OpGroupSMax>;
+defm : DemangledGroupBuiltinWrapper<"__spirv_GroupSMax", 3, 3, OpGroupSMax>;
// cl_khr_subgroup_non_uniform_arithmetic
defm : DemangledGroupBuiltin<"group_non_uniform_iadd", WorkOrSub, OpGroupNonUniformIAdd>;
diff --git a/llvm/test/CodeGen/SPIRV/transcoding/OpGroupAllAny.ll b/llvm/test/CodeGen/SPIRV/transcoding/OpGroupAllAny.ll
index 5f11ca2a474e9..39f9e0581d2f8 100644
--- a/llvm/test/CodeGen/SPIRV/transcoding/OpGroupAllAny.ll
+++ b/llvm/test/CodeGen/SPIRV/transcoding/OpGroupAllAny.ll
@@ -1,18 +1,29 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
; CHECK-SPIRV: OpCapability Groups
-; CHECK-SPIRV: %[[#BoolTypeID:]] = OpTypeBool
-; CHECK-SPIRV: %[[#ConstID:]] = OpConstantTrue %[[#BoolTypeID]]
-; CHECK-SPIRV: %[[#]] = OpGroupAll %[[#BoolTypeID]] %[[#]] %[[#ConstID]]
-; CHECK-SPIRV: %[[#]] = OpGroupAny %[[#BoolTypeID]] %[[#]] %[[#ConstID]]
+; CHECK-SPIRV-DAG: %[[#BoolTypeID:]] = OpTypeBool
+; CHECK-SPIRV-DAG: %[[#True:]] = OpConstantTrue %[[#BoolTypeID]]
+; CHECK-SPIRV-DAG: %[[#False:]] = OpConstantFalse %[[#BoolTypeID]]
+; CHECK-SPIRV: %[[#]] = OpGroupAll %[[#BoolTypeID]] %[[#]] %[[#True]]
+; CHECK-SPIRV: %[[#]] = OpGroupAny %[[#BoolTypeID]] %[[#]] %[[#True]]
+; CHECK-SPIRV: %[[#]] = OpGroupAll %[[#BoolTypeID]] %[[#]] %[[#True]]
+; CHECK-SPIRV: %[[#]] = OpGroupAny %[[#BoolTypeID]] %[[#]] %[[#False]]
define spir_kernel void @test(i32 addrspace(1)* nocapture readnone %i) {
entry:
%call = tail call spir_func i32 @_Z14work_group_alli(i32 5)
%call1 = tail call spir_func i32 @_Z14work_group_anyi(i32 5)
+ %call3 = tail call spir_func i1 @__spirv_GroupAll(i32 0, i1 1) ; returns i1, matching the declaration
+ %call4 = tail call spir_func i1 @__spirv_GroupAny(i32 0, i1 0) ; returns i1, matching the declaration
ret void
}
declare spir_func i32 @_Z14work_group_alli(i32)
-
declare spir_func i32 @_Z14work_group_anyi(i32)
+
+declare spir_func i1 @__spirv_GroupAll(i32, i1)
+declare spir_func i1 @__spirv_GroupAny(i32, i1)
diff --git a/llvm/test/CodeGen/SPIRV/transcoding/group_ops.ll b/llvm/test/CodeGen/SPIRV/transcoding/group_ops.ll
index 2412f406a9c62..a0a8c64faedd7 100644
--- a/llvm/test/CodeGen/SPIRV/transcoding/group_ops.ll
+++ b/llvm/test/CodeGen/SPIRV/transcoding/group_ops.ll
@@ -1,8 +1,12 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
; CHECK-SPIRV-DAG: %[[#int:]] = OpTypeInt 32 0
; CHECK-SPIRV-DAG: %[[#float:]] = OpTypeFloat 32
+; CHECK-SPIRV-DAG: %[[#ScopeCrossWorkgroup:]] = OpConstant %[[#int]] 0
; CHECK-SPIRV-DAG: %[[#ScopeWorkgroup:]] = OpConstant %[[#int]] 2
; CHECK-SPIRV-DAG: %[[#ScopeSubgroup:]] = OpConstant %[[#int]] 3
@@ -247,7 +251,8 @@ entry:
declare spir_func i32 @_Z21work_group_reduce_minj(i32 noundef) local_unnamed_addr
; CHECK-SPIRV: OpFunction
-; CHECK-SPIRV: %[[#]] = OpGroupBroadcast %[[#int]] %[[#ScopeWorkgroup]]
+; CHECK-SPIRV: %[[#]] = OpGroupBroadcast %[[#int]] %[[#ScopeWorkgroup]] %[[#BroadcastValue:]] %[[#BroadcastLocalId:]]
+; CHECK-SPIRV: %[[#]] = OpGroupBroadcast %[[#int]] %[[#ScopeCrossWorkgroup]] %[[#BroadcastValue]] %[[#BroadcastLocalId]]
; CHECK-SPIRV: OpFunctionEnd
;; kernel void testWorkGroupBroadcast(uint a, global size_t *id, global int *res) {
@@ -259,7 +264,9 @@ entry:
%0 = load i32, i32 addrspace(1)* %id, align 4
%call = call spir_func i32 @_Z20work_group_broadcastjj(i32 noundef %a, i32 noundef %0)
store i32 %call, i32 addrspace(1)* %res, align 4
+ %call1 = call spir_func i32 @__spirv_GroupBroadcast(i32 0, i32 noundef %a, i32 noundef %0)
ret void
}
declare spir_func i32 @_Z20work_group_broadcastjj(i32 noundef, i32 noundef) local_unnamed_addr
+declare spir_func i32 @__spirv_GroupBroadcast(i32 noundef, i32 noundef, i32 noundef) local_unnamed_addr
>From 5fc6aa88843fffd4a318beffab4db9e2493093a8 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Wed, 12 Jun 2024 03:35:32 -0700
Subject: [PATCH 2/2] fix and add test cases for instructions with a Group
Operation
---
llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp | 33 ++++++++++++++++---
.../CodeGen/SPIRV/transcoding/group_ops.ll | 33 +++++++++++++++++++
2 files changed, 62 insertions(+), 4 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index bfe2e01387279..cfed7e584e7bc 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -1016,11 +1016,36 @@ static bool generateGroupInst(const SPIRV::IncomingCall *Call,
const SPIRV::GroupBuiltin *GroupBuiltin =
SPIRV::lookupGroupBuiltin(Builtin->Name);
- if (Call->isSpirvOp())
- return buildOpFromWrapper(MIRBuilder, GroupBuiltin->Opcode, Call,
- GR->getSPIRVTypeID(Call->ReturnType));
-
MachineRegisterInfo *MRI = MIRBuilder.getMRI();
+ if (Call->isSpirvOp()) {
+ if (GroupBuiltin->NoGroupOperation)
+ return buildOpFromWrapper(MIRBuilder, GroupBuiltin->Opcode, Call,
+ GR->getSPIRVTypeID(Call->ReturnType));
+
+ // Group Operation is a literal
+ Register GroupOpReg = Call->Arguments[1];
+ const MachineInstr *MI = getDefInstrMaybeConstant(GroupOpReg, MRI);
+ if (!MI || MI->getOpcode() != TargetOpcode::G_CONSTANT) {
+ std::string DiagMsg = std::string(Builtin->Name) +
+ ": expect a constant value of Group Operation";
+ report_fatal_error(DiagMsg.c_str());
+ }
+ uint64_t GrpOp = MI->getOperand(1).getCImm()->getValue().getZExtValue();
+ Register ScopeReg = Call->Arguments[0];
+ if (!MRI->getRegClassOrNull(ScopeReg))
+ MRI->setRegClass(ScopeReg, &SPIRV::IDRegClass);
+ Register ValueReg = Call->Arguments[2];
+ if (!MRI->getRegClassOrNull(ValueReg))
+ MRI->setRegClass(ValueReg, &SPIRV::IDRegClass);
+ MIRBuilder.buildInstr(GroupBuiltin->Opcode)
+ .addDef(Call->ReturnRegister)
+ .addUse(GR->getSPIRVTypeID(Call->ReturnType))
+ .addUse(ScopeReg)
+ .addImm(GrpOp)
+ .addUse(ValueReg);
+ return true;
+ }
+
Register Arg0;
if (GroupBuiltin->HasBoolArg) {
Register ConstRegister = Call->Arguments[0];
diff --git a/llvm/test/CodeGen/SPIRV/transcoding/group_ops.ll b/llvm/test/CodeGen/SPIRV/transcoding/group_ops.ll
index a0a8c64faedd7..a40a61a8dc7d7 100644
--- a/llvm/test/CodeGen/SPIRV/transcoding/group_ops.ll
+++ b/llvm/test/CodeGen/SPIRV/transcoding/group_ops.ll
@@ -270,3 +270,36 @@ entry:
declare spir_func i32 @_Z20work_group_broadcastjj(i32 noundef, i32 noundef) local_unnamed_addr
declare spir_func i32 @__spirv_GroupBroadcast(i32 noundef, i32 noundef, i32 noundef) local_unnamed_addr
+
+; CHECK-SPIRV: OpFunction
+; CHECK-SPIRV: %[[#]] = OpGroupFAdd %[[#float]] %[[#ScopeCrossWorkgroup]] Reduce %[[#FValue:]]
+; CHECK-SPIRV: %[[#]] = OpGroupFMin %[[#float]] %[[#ScopeWorkgroup]] InclusiveScan %[[#FValue]]
+; CHECK-SPIRV: %[[#]] = OpGroupFMax %[[#float]] %[[#ScopeSubgroup]] ExclusiveScan %[[#FValue]]
+; CHECK-SPIRV: %[[#]] = OpGroupIAdd %[[#int]] %[[#ScopeCrossWorkgroup]] Reduce %[[#IValue:]]
+; CHECK-SPIRV: %[[#]] = OpGroupUMin %[[#int]] %[[#ScopeWorkgroup]] InclusiveScan %[[#IValue]]
+; CHECK-SPIRV: %[[#]] = OpGroupSMin %[[#int]] %[[#ScopeSubgroup]] ExclusiveScan %[[#IValue]]
+; CHECK-SPIRV: %[[#]] = OpGroupUMax %[[#int]] %[[#ScopeCrossWorkgroup]] Reduce %[[#IValue]]
+; CHECK-SPIRV: %[[#]] = OpGroupSMax %[[#int]] %[[#ScopeWorkgroup]] InclusiveScan %[[#IValue]]
+; CHECK-SPIRV: OpFunctionEnd
+
+define spir_kernel void @foo(float %a, i32 %b) {
+entry:
+ %f1 = call spir_func float @__spirv_GroupFAdd(i32 0, i32 0, float %a)
+ %f2 = call spir_func float @__spirv_GroupFMin(i32 2, i32 1, float %a)
+ %f3 = call spir_func float @__spirv_GroupFMax(i32 3, i32 2, float %a)
+ %i1 = call spir_func i32 @__spirv_GroupIAdd(i32 0, i32 0, i32 %b)
+ %i2 = call spir_func i32 @__spirv_GroupUMin(i32 2, i32 1, i32 %b)
+ %i3 = call spir_func i32 @__spirv_GroupSMin(i32 3, i32 2, i32 %b)
+ %i4 = call spir_func i32 @__spirv_GroupUMax(i32 0, i32 0, i32 %b)
+ %i5 = call spir_func i32 @__spirv_GroupSMax(i32 2, i32 1, i32 %b)
+ ret void
+}
+
+declare spir_func float @__spirv_GroupFAdd(i32, i32, float)
+declare spir_func float @__spirv_GroupFMin(i32, i32, float)
+declare spir_func float @__spirv_GroupFMax(i32, i32, float)
+declare spir_func i32 @__spirv_GroupIAdd(i32, i32, i32)
+declare spir_func i32 @__spirv_GroupUMin(i32, i32, i32)
+declare spir_func i32 @__spirv_GroupSMin(i32, i32, i32)
+declare spir_func i32 @__spirv_GroupUMax(i32, i32, i32)
+declare spir_func i32 @__spirv_GroupSMax(i32, i32, i32)
More information about the llvm-commits
mailing list